mirror of
https://github.com/anoma/juvix.git
synced 2025-01-08 16:51:53 +03:00
Bugfix: compiler looping with the specialize
pragma (#2899)
* Closes #2884
This commit is contained in:
parent
7d2a59cc9f
commit
5a76e5d9dc
@ -1 +1 @@
|
||||
Subproject commit 216cb609cbe5aec9badea858f151a5ea400f2e66
|
||||
Subproject commit 17f22fcec5d78be511ea59984aee3499da5f3342
|
@ -101,3 +101,6 @@ nonRecursiveIdents' tab =
|
||||
HashSet.difference
|
||||
(HashSet.fromList (HashMap.keys (tab ^. infoIdentifiers)))
|
||||
(recursiveIdentsClosure tab)
|
||||
|
||||
nonRecursiveIdents :: Module -> HashSet Symbol
|
||||
nonRecursiveIdents = nonRecursiveIdents' . computeCombinedInfoTable
|
||||
|
@ -39,6 +39,7 @@ data TransformationId
|
||||
| SpecializeArgs
|
||||
| CaseFolding
|
||||
| CasePermutation
|
||||
| ConstantFolding
|
||||
| FilterUnreachable
|
||||
| OptPhaseEval
|
||||
| OptPhaseExec
|
||||
@ -113,6 +114,7 @@ instance TransformationId' TransformationId where
|
||||
SpecializeArgs -> strSpecializeArgs
|
||||
CaseFolding -> strCaseFolding
|
||||
CasePermutation -> strCasePermutation
|
||||
ConstantFolding -> strConstantFolding
|
||||
FilterUnreachable -> strFilterUnreachable
|
||||
OptPhaseEval -> strOptPhaseEval
|
||||
OptPhaseExec -> strOptPhaseExec
|
||||
|
@ -119,6 +119,9 @@ strCaseFolding = "case-folding"
|
||||
strCasePermutation :: Text
|
||||
strCasePermutation = "case-permutation"
|
||||
|
||||
strConstantFolding :: Text
|
||||
strConstantFolding = "constant-folding"
|
||||
|
||||
strFilterUnreachable :: Text
|
||||
strFilterUnreachable = "filter-unreachable"
|
||||
|
||||
|
@ -36,6 +36,7 @@ import Juvix.Compiler.Core.Transformation.Normalize
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.CaseCallLifting
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.CasePermutation (casePermutation)
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.ConstantFolding
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable (filterUnreachable)
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
|
||||
@ -96,6 +97,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
|
||||
SpecializeArgs -> return . specializeArgs
|
||||
CaseFolding -> return . caseFolding
|
||||
CasePermutation -> return . casePermutation
|
||||
ConstantFolding -> constantFolding
|
||||
FilterUnreachable -> return . filterUnreachable
|
||||
OptPhaseEval -> Phase.Eval.optimize
|
||||
OptPhaseExec -> Phase.Exec.optimize
|
||||
|
@ -17,7 +17,7 @@ isInlineableLambda inlineDepth md bl node = case node of
|
||||
False
|
||||
|
||||
convertNode :: Int -> HashSet Symbol -> Module -> Node -> Node
|
||||
convertNode inlineDepth recSyms md = dmapL go
|
||||
convertNode inlineDepth nonRecSyms md = dmapL go
|
||||
where
|
||||
go :: BinderList Binder -> Node -> Node
|
||||
go bl node = case node of
|
||||
@ -37,7 +37,7 @@ convertNode inlineDepth recSyms md = dmapL go
|
||||
Just InlineNever ->
|
||||
node
|
||||
_
|
||||
| not (HashSet.member _identSymbol recSyms)
|
||||
| HashSet.member _identSymbol nonRecSyms
|
||||
&& isInlineableLambda inlineDepth md bl def
|
||||
&& length args >= argsNum ->
|
||||
mkApps def args
|
||||
@ -57,7 +57,7 @@ convertNode inlineDepth recSyms md = dmapL go
|
||||
Just InlineAlways -> def
|
||||
Just InlineNever -> node
|
||||
_
|
||||
| not (HashSet.member _identSymbol recSyms)
|
||||
| HashSet.member _identSymbol nonRecSyms
|
||||
&& isImmediate md def ->
|
||||
def
|
||||
| otherwise ->
|
||||
@ -76,7 +76,7 @@ convertNode inlineDepth recSyms md = dmapL go
|
||||
Just InlineCase ->
|
||||
NCase cs {_caseValue = mkApps def args}
|
||||
Nothing
|
||||
| not (HashSet.member _identSymbol recSyms)
|
||||
| HashSet.member _identSymbol nonRecSyms
|
||||
&& isConstructorApp def
|
||||
&& checkDepth md bl inlineDepth def ->
|
||||
NCase cs {_caseValue = mkApps def args}
|
||||
@ -92,9 +92,9 @@ convertNode inlineDepth recSyms md = dmapL go
|
||||
node
|
||||
|
||||
inlining' :: Int -> HashSet Symbol -> Module -> Module
|
||||
inlining' inliningDepth recSyms md = mapT (const (convertNode inliningDepth recSyms md)) md
|
||||
inlining' inliningDepth nonRecSyms md = mapT (const (convertNode inliningDepth nonRecSyms md)) md
|
||||
|
||||
inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module
|
||||
inlining md = do
|
||||
d <- asks (^. optInliningDepth)
|
||||
return $ inlining' d (recursiveIdents md) md
|
||||
return $ inlining' d (nonRecursiveIdents md) md
|
||||
|
@ -33,9 +33,6 @@ optimize' opts@CoreOptions {..} md =
|
||||
tab :: InfoTable
|
||||
tab = computeCombinedInfoTable md
|
||||
|
||||
recs :: HashSet Symbol
|
||||
recs = recursiveIdents' tab
|
||||
|
||||
nonRecs :: HashSet Symbol
|
||||
nonRecs = nonRecursiveIdents' tab
|
||||
|
||||
@ -48,12 +45,12 @@ optimize' opts@CoreOptions {..} md =
|
||||
| otherwise = nonRecs
|
||||
|
||||
doInlining :: Module -> Module
|
||||
doInlining md' = inlining' _optInliningDepth recs' md'
|
||||
doInlining md' = inlining' _optInliningDepth nonRecs' md'
|
||||
where
|
||||
recs' =
|
||||
nonRecs' =
|
||||
if
|
||||
| _optOptimizationLevel > 1 -> recursiveIdents md'
|
||||
| otherwise -> recs
|
||||
| _optOptimizationLevel > 1 -> nonRecursiveIdents md'
|
||||
| otherwise -> nonRecs
|
||||
|
||||
doSimplification :: Int -> Module -> Module
|
||||
doSimplification n =
|
||||
|
@ -9,20 +9,12 @@ mymap {A B} (f : A -> B) : List A -> List B
|
||||
| (x :: xs) := f x :: mymap f xs;
|
||||
|
||||
{-# specialize: [2, 5], inline: false #-}
|
||||
myf
|
||||
: {A B : Type}
|
||||
-> A
|
||||
-> (A -> A -> B)
|
||||
-> A
|
||||
-> B
|
||||
-> Bool
|
||||
-> B
|
||||
myf : {A B : Type} -> A -> (A -> A -> B) -> A -> B -> Bool -> B
|
||||
| a0 f a b true := f a0 a
|
||||
| a0 f a b false := b;
|
||||
|
||||
{-# inline: false #-}
|
||||
myf'
|
||||
: {A B : Type} -> A -> (A -> A -> A -> B) -> A -> B -> B
|
||||
myf' : {A B : Type} -> A -> (A -> A -> A -> B) -> A -> B -> B
|
||||
| a0 f a b := myf a0 (f a0) a b true;
|
||||
|
||||
sum : List Nat -> Nat
|
||||
@ -40,8 +32,7 @@ funa : {A : Type} -> (A -> A) -> A -> A
|
||||
{-# specialize: true #-}
|
||||
type Additive A := mkAdditive {add : A -> A -> A};
|
||||
|
||||
type Multiplicative A :=
|
||||
mkMultiplicative {mul : A -> A -> A};
|
||||
type Multiplicative A := mkMultiplicative {mul : A -> A -> A};
|
||||
|
||||
addNat : Additive Nat := mkAdditive (+);
|
||||
|
||||
@ -49,20 +40,20 @@ addNat : Additive Nat := mkAdditive (+);
|
||||
mulNat : Multiplicative Nat := mkMultiplicative (*);
|
||||
|
||||
{-# inline: false #-}
|
||||
fadd {A} (a : Additive A) (x y : A) : A :=
|
||||
Additive.add a x y;
|
||||
fadd {A} (a : Additive A) (x y : A) : A := Additive.add a x y;
|
||||
|
||||
{-# inline: false #-}
|
||||
fmul {A} (m : Multiplicative A) (x y : A) : A :=
|
||||
Multiplicative.mul m x y;
|
||||
fmul {A} (m : Multiplicative A) (x y : A) : A := Multiplicative.mul m x y;
|
||||
|
||||
{-# specialize: [1] #-}
|
||||
myfilter {A} (f : A → Bool) : List A → List A
|
||||
| nil := nil
|
||||
| (h :: hs) := ite (f h) (h :: myfilter f hs) (myfilter f hs);
|
||||
|
||||
main : Nat :=
|
||||
sum (mymap λ {x := x + 3} (1 :: 2 :: 3 :: 4 :: nil))
|
||||
+ sum
|
||||
(flatten
|
||||
(mymap
|
||||
(mymap λ {x := x + 2})
|
||||
((1 :: nil) :: (2 :: 3 :: nil) :: nil)))
|
||||
sum (myfilter (const false) [])
|
||||
+ sum (mymap λ {x := x + 3} (1 :: 2 :: 3 :: 4 :: nil))
|
||||
+ sum (flatten (mymap (mymap λ {x := x + 2}) ((1 :: nil) :: (2 :: 3 :: nil) :: nil)))
|
||||
+ myf 3 (*) 2 5 true
|
||||
+ myf 1 (+) 2 0 false
|
||||
+ myf' 7 (const (+)) 2 0
|
||||
|
Loading…
Reference in New Issue
Block a user