mirror of
https://github.com/anoma/juvix.git
synced 2024-11-30 14:13:27 +03:00
Arithmetic simplification (#2454)
Simplifies arithmetic expressions in the Core optimization phase, changing e.g. `(x - 1) + 1` to `x`. Such expressions appear as a result of compiling pattern matching on natural numbers.
This commit is contained in:
parent
c1c2a06922
commit
cdfb35aaac
@ -10,6 +10,7 @@ import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyArithmetic
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.SpecializeArgs
|
||||
@ -53,7 +54,8 @@ optimize' CoreOptions {..} tab =
|
||||
|
||||
doSimplification :: Int -> InfoTable -> InfoTable
|
||||
doSimplification n =
|
||||
simplifyIfs' (_optOptimizationLevel <= 1)
|
||||
simplifyArithmetic
|
||||
. simplifyIfs' (_optOptimizationLevel <= 1)
|
||||
. simplifyComparisons
|
||||
. caseFolding
|
||||
. casePermutation
|
||||
|
@ -0,0 +1,64 @@
|
||||
module Juvix.Compiler.Core.Transformation.Optimize.SimplifyArithmetic (simplifyArithmetic) where
|
||||
|
||||
import Juvix.Compiler.Core.Extra
|
||||
import Juvix.Compiler.Core.Transformation.Base
|
||||
|
||||
convertNode :: Node -> Node
|
||||
convertNode = dmap go
|
||||
where
|
||||
go :: Node -> Node
|
||||
go node = case node of
|
||||
NBlt BuiltinApp {..}
|
||||
| _builtinAppOp == OpIntAdd,
|
||||
[NBlt blt', n] <- _builtinAppArgs,
|
||||
blt' ^. builtinAppOp == OpIntSub,
|
||||
[x, m] <- blt' ^. builtinAppArgs,
|
||||
m == n ->
|
||||
x
|
||||
NBlt BuiltinApp {..}
|
||||
| _builtinAppOp == OpIntSub,
|
||||
[NBlt blt', n] <- _builtinAppArgs,
|
||||
blt' ^. builtinAppOp == OpIntAdd,
|
||||
[x, m] <- blt' ^. builtinAppArgs ->
|
||||
if
|
||||
| m == n ->
|
||||
x
|
||||
| x == n ->
|
||||
m
|
||||
| otherwise ->
|
||||
node
|
||||
NBlt BuiltinApp {..}
|
||||
| _builtinAppOp == OpIntAdd,
|
||||
[n, NBlt blt'] <- _builtinAppArgs,
|
||||
blt' ^. builtinAppOp == OpIntSub,
|
||||
[x, m] <- blt' ^. builtinAppArgs,
|
||||
m == n ->
|
||||
x
|
||||
NBlt BuiltinApp {..}
|
||||
| _builtinAppOp == OpIntAdd || _builtinAppOp == OpIntSub,
|
||||
[x, NCst (Constant _ (ConstInteger 0))] <- _builtinAppArgs ->
|
||||
x
|
||||
NBlt BuiltinApp {..}
|
||||
| _builtinAppOp == OpIntAdd,
|
||||
[NCst (Constant _ (ConstInteger 0)), x] <- _builtinAppArgs ->
|
||||
x
|
||||
NBlt BuiltinApp {..}
|
||||
| _builtinAppOp == OpIntMul,
|
||||
[_, c@(NCst (Constant _ (ConstInteger 0)))] <- _builtinAppArgs ->
|
||||
c
|
||||
NBlt BuiltinApp {..}
|
||||
| _builtinAppOp == OpIntMul,
|
||||
[c@(NCst (Constant _ (ConstInteger 0))), _] <- _builtinAppArgs ->
|
||||
c
|
||||
NBlt BuiltinApp {..}
|
||||
| _builtinAppOp == OpIntMul,
|
||||
[x, NCst (Constant _ (ConstInteger 1))] <- _builtinAppArgs ->
|
||||
x
|
||||
NBlt BuiltinApp {..}
|
||||
| _builtinAppOp == OpIntMul,
|
||||
[NCst (Constant _ (ConstInteger 1)), x] <- _builtinAppArgs ->
|
||||
x
|
||||
_ -> node
|
||||
|
||||
simplifyArithmetic :: InfoTable -> InfoTable
|
||||
simplifyArithmetic = mapAllNodes convertNode
|
@ -382,5 +382,10 @@ tests =
|
||||
"Test064: Constant folding"
|
||||
$(mkRelDir ".")
|
||||
$(mkRelFile "test064.juvix")
|
||||
$(mkRelFile "out/test064.out")
|
||||
$(mkRelFile "out/test064.out"),
|
||||
posTest
|
||||
"Test065: Arithmetic simplification"
|
||||
$(mkRelDir ".")
|
||||
$(mkRelFile "test065.juvix")
|
||||
$(mkRelFile "out/test065.out")
|
||||
]
|
||||
|
1
tests/Compilation/positive/out/test065.out
Normal file
1
tests/Compilation/positive/out/test065.out
Normal file
@ -0,0 +1 @@
|
||||
42
|
15
tests/Compilation/positive/test065.juvix
Normal file
15
tests/Compilation/positive/test065.juvix
Normal file
@ -0,0 +1,15 @@
|
||||
-- Arithmetic simplification
|
||||
module test065;
|
||||
|
||||
import Stdlib.Prelude open;
|
||||
|
||||
{-# inline: false #-}
|
||||
f (x : Int) : Int :=
|
||||
(x + fromNat 1 - fromNat 1) * fromNat 1
|
||||
+ fromNat 0 * x
|
||||
+ (fromNat 10 + (x - fromNat 10))
|
||||
+ (fromNat 10 + x - fromNat 10)
|
||||
+ (fromNat 11 + (fromNat 11 - x))
|
||||
+ fromNat 1 * x * fromNat 0 * fromNat 1;
|
||||
|
||||
main : Int := f (fromNat 10);
|
Loading…
Reference in New Issue
Block a user