1
1
mirror of https://github.com/anoma/juvix.git synced 2024-07-07 04:36:19 +03:00

Peephole optimization of Cairo assembly (#2858)

* Closes #2703 
* Adds [peephole
optimization](https://en.wikipedia.org/wiki/Peephole_optimization) of
Cairo assembly.
* Adds a transformation framework for the CASM IR.
* Adds `--transforms`, `--run` and `--no-print` options to the `dev casm
read` command.
This commit is contained in:
Łukasz Czajka 2024-06-27 12:41:27 +02:00 committed by GitHub
parent 4dcbb002fe
commit 802d82f22e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 327 additions and 38 deletions

View File

@ -2,7 +2,11 @@ module Commands.Dev.Casm.Read where
import Commands.Base
import Commands.Dev.Casm.Read.Options
import Juvix.Compiler.Casm.Pretty qualified as Casm
import Juvix.Compiler.Casm.Data.InputInfo qualified as Casm
import Juvix.Compiler.Casm.Extra.LabelInfo qualified as Casm
import Juvix.Compiler.Casm.Interpreter qualified as Casm
import Juvix.Compiler.Casm.Pretty qualified as Casm.Pretty
import Juvix.Compiler.Casm.Transformation qualified as Casm
import Juvix.Compiler.Casm.Translation.FromSource qualified as Casm
import Juvix.Compiler.Casm.Validate qualified as Casm
@ -15,7 +19,28 @@ runCommand opts = do
Right (labi, code) ->
case Casm.validate labi code of
Left err -> exitJuvixError (JuvixError err)
Right () -> renderStdOut (Casm.ppProgram code)
Right () -> do
r <-
runError @JuvixError
. runReader Casm.defaultOptions
$ (Casm.applyTransformations (project opts ^. casmReadTransformations) code)
case r of
Left err -> exitJuvixError (JuvixError err)
Right code' -> do
unless (project opts ^. casmReadNoPrint) $
renderStdOut (Casm.Pretty.ppProgram code')
doRun code'
where
file :: AppPath File
file = opts ^. casmReadInputFile
doRun :: Casm.Code -> Sem r ()
doRun code'
| project opts ^. casmReadRun = do
putStrLn "--------------------------------"
putStrLn "| Run |"
putStrLn "--------------------------------"
let labi = Casm.computeLabelInfo code'
inputInfo = Casm.InputInfo mempty
print (Casm.runCode inputInfo labi code')
| otherwise = return ()

View File

@ -1,9 +1,13 @@
module Commands.Dev.Casm.Read.Options where
import CommonOptions
import Juvix.Compiler.Casm.Data.TransformationId
newtype CasmReadOptions = CasmReadOptions
{ _casmReadInputFile :: AppPath File
-- | Options of the @dev casm read@ command.
data CasmReadOptions = CasmReadOptions
{ _casmReadTransformations :: [TransformationId],
-- ^ Transformations to apply to the code that was read, in list order.
_casmReadRun :: Bool,
-- ^ Whether to run the code after the transformations.
_casmReadNoPrint :: Bool,
-- ^ When set, the transformed code is not printed.
_casmReadInputFile :: AppPath File
-- ^ The input file.
}
deriving stock (Data)
@ -11,5 +15,8 @@ makeLenses ''CasmReadOptions
-- | Command-line parser for 'CasmReadOptions'. It combines the shared
-- 'optReadNoPrint' and 'optReadRun' switches, the CASM transformation list
-- ('optCasmTransformationIds') and the input file parser for 'FileExtCasm'.
parseCasmReadOptions :: Parser CasmReadOptions
parseCasmReadOptions = do
_casmReadNoPrint <- optReadNoPrint
_casmReadRun <- optReadRun
_casmReadTransformations <- optCasmTransformationIds
_casmReadInputFile <- parseInputFile FileExtCasm
pure CasmReadOptions {..}

View File

@ -15,16 +15,8 @@ makeLenses ''RegReadOptions
parseRegReadOptions :: Parser RegReadOptions
parseRegReadOptions = do
_regReadNoPrint <-
switch
( long "no-print"
<> help "Do not print the transformed code"
)
_regReadRun <-
switch
( long "run"
<> help "Run the code after the transformation"
)
_regReadNoPrint <- optReadNoPrint
_regReadRun <- optReadRun
_regReadTransformations <- optRegTransformationIds
_regReadInputFile <- parseInputFile FileExtJuvixReg
pure RegReadOptions {..}

View File

@ -10,6 +10,7 @@ where
import Control.Exception qualified as GHC
import Data.List.NonEmpty qualified as NonEmpty
import GHC.Conc
import Juvix.Compiler.Casm.Data.TransformationId.Parser qualified as Casm
import Juvix.Compiler.Concrete.Translation.ImportScanner
import Juvix.Compiler.Core.Data.TransformationId.Parser qualified as Core
import Juvix.Compiler.Pipeline.EntryPoint
@ -282,6 +283,20 @@ optNoDisambiguate =
<> help "Don't disambiguate the names of bound variables"
)
-- | The @--run@ switch shared by the @read@ commands: when given, the code is
-- run after the transformations have been applied.
optReadRun :: Parser Bool
optReadRun =
  switch $
    long "run"
      <> help "Run the code after the transformation"
-- | The @--no-print@ switch shared by the @read@ commands: when given, the
-- transformed code is not printed.
optReadNoPrint :: Parser Bool
optReadNoPrint =
  switch $
    long "no-print"
      <> help "Do not print the transformed code"
optTransformationIds :: forall a. (Text -> Either Text [a]) -> (String -> [String]) -> Parser [a]
optTransformationIds parseIds completions =
option
@ -317,6 +332,9 @@ optTreeTransformationIds = optTransformationIds Tree.parseTransformations Tree.c
-- | Transformation list option for JuvixReg, built from the generic
-- 'optTransformationIds' with the JuvixReg parser and completions.
optRegTransformationIds :: Parser [Reg.TransformationId]
optRegTransformationIds = optTransformationIds Reg.parseTransformations Reg.completionsString
-- | Transformation list option for CASM, built from the generic
-- 'optTransformationIds' with the CASM parser and completions.
optCasmTransformationIds :: Parser [Casm.TransformationId]
optCasmTransformationIds = optTransformationIds Casm.parseTransformations Casm.completionsString
class EntryPointOptions a where
applyOptions :: a -> EntryPoint -> EntryPoint

View File

@ -114,6 +114,9 @@ juvix_ec_op:
-- [fp - 3]: closure
-- [fp - 4]: n = the number of arguments to extend with
-- [fp - 4 - k]: argument n - k - 1 (reverse order!) (k is 0-based)
-- On return:
-- [ap - 1]: new closure
-- This procedure doesn't accept or return the builtins pointer.
juvix_extend_closure:
-- copy stored args reversing them;
-- to copy the stored args to the new closure

View File

@ -0,0 +1,34 @@
module Juvix.Compiler.Casm.Data.TransformationId where
import Juvix.Compiler.Casm.Data.TransformationId.Strings
import Juvix.Compiler.Core.Data.TransformationId.Base
import Juvix.Prelude
-- | Identifiers of the transformations available for CASM code.
data TransformationId
= IdentityTrans
-- ^ Leaves the code unchanged.
| Peephole
-- ^ Peephole optimization.
deriving stock (Data, Bounded, Enum, Show)
-- | Identifiers of named CASM transformation pipelines. A pipeline stands for
-- a fixed list of transformations (see the 'PipelineId'' instance).
data PipelineId
= PipelineCairo
-- ^ Expands to 'toCairoTransformations'.
deriving stock (Data, Bounded, Enum)
-- | The generic 'TransformationLikeId'' specialised to the CASM
-- 'TransformationId' and 'PipelineId'.
type TransformationLikeId = TransformationLikeId' TransformationId PipelineId
-- | The transformations making up the 'PipelineCairo' pipeline; currently
-- only the peephole optimization.
toCairoTransformations :: [TransformationId]
toCairoTransformations = [Peephole]
-- | Textual names of the CASM transformations.
instance TransformationId' TransformationId where
  transformationText :: TransformationId -> Text
  transformationText IdentityTrans = strIdentity
  transformationText Peephole = strPeephole
-- | Textual names of the CASM pipelines and the transformations they expand
-- to.
instance PipelineId' TransformationId PipelineId where
  pipelineText :: PipelineId -> Text
  pipelineText PipelineCairo = strCairoPipeline

  pipeline :: PipelineId -> [TransformationId]
  pipeline PipelineCairo = toCairoTransformations

View File

@ -0,0 +1,14 @@
module Juvix.Compiler.Casm.Data.TransformationId.Parser (parseTransformations, TransformationId (..), completions, completionsString) where
import Juvix.Compiler.Casm.Data.TransformationId
import Juvix.Compiler.Core.Data.TransformationId.Parser.Base
import Juvix.Prelude
-- | Parse CASM transformation identifiers from text, using the generic
-- 'parseTransformations'' instantiated at the CASM 'TransformationId' and
-- 'PipelineId'. The result is either an error text or the transformations.
parseTransformations :: Text -> Either Text [TransformationId]
parseTransformations input =
  parseTransformations' @TransformationId @PipelineId input
-- | 'String' completions for a CASM transformation list, using the generic
-- 'completionsString'' instantiated at the CASM identifier types.
completionsString :: String -> [String]
completionsString input =
  completionsString' @TransformationId @PipelineId input
-- | 'Text' completions for a CASM transformation list, using the generic
-- 'completions'' instantiated at the CASM identifier types.
completions :: Text -> [Text]
completions input =
  completions' @TransformationId @PipelineId input

View File

@ -0,0 +1,12 @@
module Juvix.Compiler.Casm.Data.TransformationId.Strings where
import Juvix.Prelude
-- | Textual name of the Cairo pipeline.
strCairoPipeline :: Text
strCairoPipeline = "pipeline-cairo"
-- | Textual name of the identity transformation.
strIdentity :: Text
strIdentity = "identity"
-- | Textual name of the peephole optimization.
strPeephole :: Text
strPeephole = "peephole"

View File

@ -0,0 +1,17 @@
module Juvix.Compiler.Casm.Pipeline
( module Juvix.Compiler.Casm.Pipeline,
Options,
Code,
)
where
import Juvix.Compiler.Casm.Transformation
import Juvix.Compiler.Pipeline.EntryPoint (EntryPoint)
-- | Apply the 'toCairoTransformations' pipeline to CASM code, i.e. the
-- transformations performed before the translation to Cairo bytecode.
toCairo' :: Code -> Sem r Code
toCairo' code = applyTransformations toCairoTransformations code
-- | Variant of 'toCairo'' for effect stacks with a 'Reader' of 'EntryPoint':
-- the reader is mapped with 'fromEntryPoint' around the 'toCairo'' call.
toCairo :: (Member (Reader EntryPoint) r) => Code -> Sem r Code
toCairo code = mapReader fromEntryPoint (toCairo' code)

View File

@ -0,0 +1,18 @@
module Juvix.Compiler.Casm.Transformation
( module Juvix.Compiler.Casm.Transformation.Base,
module Juvix.Compiler.Casm.Transformation,
module Juvix.Compiler.Casm.Data.TransformationId,
)
where
import Juvix.Compiler.Casm.Data.TransformationId
import Juvix.Compiler.Casm.Transformation.Base
import Juvix.Compiler.Casm.Transformation.Optimize.Peephole
-- | Apply the given transformations to the code one after another, in list
-- order, feeding the output of each transformation into the next one.
applyTransformations :: forall r. [TransformationId] -> Code -> Sem r Code
applyTransformations ts code0 = foldM step code0 ts
  where
    -- Apply a single transformation to the code produced so far.
    step :: Code -> TransformationId -> Sem r Code
    step code = \case
      IdentityTrans -> return code
      Peephole -> return (peephole code)

View File

@ -0,0 +1,17 @@
module Juvix.Compiler.Casm.Transformation.Base
( module Juvix.Compiler.Casm.Transformation.Base,
module Juvix.Compiler.Casm.Language,
module Juvix.Compiler.Tree.Options,
)
where
import Juvix.Compiler.Casm.Language
import Juvix.Compiler.Tree.Options
-- | Rewrite an instruction list from its end towards its front. The rewriting
-- function is first applied to the empty list; then, for each instruction
-- from the last to the first, it is applied to that instruction consed onto
-- the already rewritten tail.
mapT :: ([Instruction] -> [Instruction]) -> [Instruction] -> [Instruction]
mapT f = foldr (\instr rewritten -> f (instr : rewritten)) (f [])

View File

@ -0,0 +1,78 @@
module Juvix.Compiler.Casm.Transformation.Optimize.Peephole where
import Juvix.Compiler.Casm.Extra.Base
import Juvix.Compiler.Casm.Language
import Juvix.Compiler.Casm.Transformation.Base
-- | Peephole optimization of CASM code. The rewrite rules of @go@ are applied
-- with 'mapT', i.e. at every position of the instruction list starting from
-- its end, so each rule sees a tail that has already been rewritten (in
-- particular, with its @nop@s removed).
peephole :: [Instruction] -> [Instruction]
peephole = mapT go
where
-- Rewrite rules on the front of the instruction list. They are tried in
-- order and the first one that matches is used; if none matches the list is
-- returned unchanged.
go :: [Instruction] -> [Instruction]
go = \case
-- Rule: drop `nop`.
Nop : is -> is
-- Rule: a jump (without `ap++`) whose target is the label that immediately
-- follows it is dropped; the label itself is kept.
Jump InstrJump {..} : lab@(Label LabelRef {..}) : is
| not _instrJumpIncAp,
Val (Lab (LabelRef sym _)) <- _instrJumpTarget,
sym == _labelRefSymbol ->
lab : is
-- Rule: `call rel 3; ret; [ap] = [fp + k]; ap++; ret` with `k <= -3`.
-- The `call rel 3; ret` prefix is dropped and the assignment is rebuilt
-- with `mkAssignAp` to read `[ap + k + 2]` instead; the result is then
-- simplified further by `fixAssignAp`.
-- NOTE(review): the `+ 2` presumably accounts for the two cells pushed by
-- `call` (saved `fp` and return address) -- confirm against the Cairo VM
-- semantics.
Call InstrCall {..} : Return : Assign a1 : Return : is
| _instrCallRel,
Imm 3 <- _instrCallTarget,
Just k1 <- getAssignApFp a1 ->
fixAssignAp $
mkAssignAp (Val (Ref (MemRef Ap k1)))
: Return
: is
-- Rule: as above, but with two assignments before the final `ret`. The
-- offset of the second assignment is additionally decreased by one,
-- presumably because the first assignment has already incremented `ap`.
Call InstrCall {..} : Return : Assign a1 : Assign a2 : Return : is
| _instrCallRel,
Imm 3 <- _instrCallTarget,
Just k1 <- getAssignApFp a1,
Just k2 <- getAssignApFp a2 ->
fixAssignAp $
mkAssignAp (Val (Ref (MemRef Ap k1)))
: mkAssignAp (Val (Ref (MemRef Ap (k2 - 1))))
: Return
: is
-- Rule: `call rel 3; ret; jmp L`, where `L` is a label and the jump has no
-- `ap++`, becomes `call L; ret`. The new call takes its target and its
-- `rel` flag from the jump.
Call InstrCall {..} : Return : Jump InstrJump {..} : is
| _instrCallRel,
Imm 3 <- _instrCallTarget,
Val tgt@(Lab {}) <- _instrJumpTarget,
not _instrJumpIncAp ->
let call =
InstrCall
{ _instrCallTarget = tgt,
_instrCallRel = _instrJumpRel
}
in Call call : Return : is
is -> is
-- Simplifies the assignment sequences produced by the `call rel 3` rules
-- above, where `[ap] = [ap + k]; ap++` denotes an assignment recognised by
-- `getAssignAp Ap` with offset `k`:
--   * `[ap] = [ap - 1]; ap++; ret` becomes `ret`;
--   * `[ap] = [ap - 2]; ap++; [ap] = [ap - 2]; ap++; ret` becomes `ret`;
--   * `[ap] = [ap - 1]; ap++; [ap] = [ap - 3]; ap++; ret` becomes a single
--     assignment from `[ap - 2]` (built with `mkAssignAp`) followed by `ret`.
-- Any other list is returned unchanged.
fixAssignAp :: [Instruction] -> [Instruction]
fixAssignAp = \case
Assign a : Return : is
| Just (-1) <- getAssignAp Ap a ->
Return : is
Assign a1 : Assign a2 : Return : is
| Just (-2) <- getAssignAp Ap a1,
Just (-2) <- getAssignAp Ap a2 ->
Return : is
Assign a1 : Assign a2 : Return : is
| Just (-1) <- getAssignAp Ap a1,
Just (-3) <- getAssignAp Ap a2 ->
mkAssignAp (Val (Ref (MemRef Ap (-2)))) : Return : is
is -> is
-- Returns `Just k` if the assignment has the form `[ap] = [reg + k]; ap++`
-- for the given register `reg` (result `[ap + 0]`, value a memory reference
-- relative to `reg`, `ap` incremented), and `Nothing` otherwise.
getAssignAp :: Reg -> InstrAssign -> Maybe Offset
getAssignAp reg InstrAssign {..}
| MemRef Ap 0 <- _instrAssignResult,
Val (Ref (MemRef r k)) <- _instrAssignValue,
r == reg,
_instrAssignIncAp =
Just k
| otherwise =
Nothing
-- For an assignment `[ap] = [fp + k]; ap++` with `k <= -3`, returns
-- `Just (k + 2)`; returns `Nothing` for every other instruction.
getAssignApFp :: InstrAssign -> Maybe Offset
getAssignApFp instr = case getAssignAp Fp instr of
Just k
| k <= -3 -> Just (k + 2)
_ -> Nothing

View File

@ -164,16 +164,18 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
-- To ensure that memory is accessed sequentially at all times, we divide
-- instructions into basic blocks. Within each basic block, the `ap` offset
-- is known at each instruction, which allows to statically associate `fp`
-- offsets to variables while still generating only sequential assignments
-- to `[ap]` with increasing `ap`. When the `ap` offset can no longer be
-- statically determined for new variables (e.g. due to an intervening
-- recursive call), we switch to the next basic block by "calling" it with
-- the `call` instruction (see `goCallBlock`). The arguments of the basic
-- block call are the variables live at the beginning of the block. Note
-- that the `fp` offsets of "old" variables are still statically determined
-- even after the current `ap` offset becomes unknown -- the arbitrary
-- increase of `ap` does not influence the previous variable associations.
-- (i.e. how much `ap` increased since the start of the basic block) is
-- known at each instruction, which allows to statically associate `fp`
-- offsets (i.e. offsets relative to `fp`) to variables while still
-- generating only sequential assignments to `[ap]` with increasing `ap`.
-- When the `ap` offset can no longer be statically determined for new
-- variables (e.g. due to an intervening recursive call), we switch to the
-- next basic block by "calling" it with the `call` instruction (see
-- `goCallBlock`). The arguments of the basic block call are the variables
-- live at the beginning of the block. Note that the `fp` offsets of "old"
-- variables are still statically determined even after the current `ap`
-- offset becomes unknown -- the arbitrary increase of `ap` does not
-- influence the previous variable associations.
goBlock :: forall r. (Members '[LabelInfoBuilder, CasmBuilder, Output Instruction] r) => StdlibBuiltins -> LabelRef -> HashSet Reg.VarRef -> Maybe Reg.VarRef -> Reg.Block -> Sem r ()
goBlock blts failLab liveVars0 mout Reg.Block {..} = do
mapM_ goInstr _blockBody
@ -645,7 +647,10 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
ap0 <- getAP
vars <- getVars
bltOff <- getBuiltinOffset
mapM_ (goCaseBranch ap0 vars bltOff symMap labEnd) _instrCaseBranches
-- Reversing `_instrCaseBranches` typically results in better
-- opportunities for peephole optimization (the last jump to a branch
-- may be removed by the peephole optimizer).
mapM_ (goCaseBranch ap0 vars bltOff symMap labEnd) (reverse _instrCaseBranches)
mapM_ (goDefaultLabel symMap) defaultTags
whenJust _instrCaseDefault $
goLocalBlock ap0 vars bltOff liveVars _instrCaseOutVar

View File

@ -23,6 +23,7 @@ import Juvix.Compiler.Backend.Rust.Translation.FromReg qualified as Rust
import Juvix.Compiler.Backend.VampIR.Translation qualified as VampIR
import Juvix.Compiler.Casm.Data.Builtins qualified as Casm
import Juvix.Compiler.Casm.Data.Result qualified as Casm
import Juvix.Compiler.Casm.Pipeline qualified as Casm
import Juvix.Compiler.Casm.Translation.FromReg qualified as Casm
import Juvix.Compiler.Concrete.Data.Highlight.Input
import Juvix.Compiler.Concrete.Language
@ -364,17 +365,25 @@ regToCasm = Reg.toCasm >=> return . Casm.fromReg
-- | Apply 'Reg.toCasm'' to the JuvixReg table and translate the result to
-- CASM with 'Casm.fromReg'.
regToCasm' :: (Member (Reader Reg.Options) r) => Reg.InfoTable -> Sem r Casm.Result
regToCasm' tab = Casm.fromReg <$> Reg.toCasm' tab
casmToCairo :: Casm.Result -> Sem r Cairo.Result
casmToCairo Casm.Result {..} =
casmToCairo :: (Member (Reader EntryPoint) r) => Casm.Result -> Sem r Cairo.Result
casmToCairo Casm.Result {..} = do
code' <- Casm.toCairo _resultCode
return
. Cairo.serialize _resultOutputSize (map Casm.builtinName _resultBuiltins)
$ Cairo.fromCasm _resultCode
$ Cairo.fromCasm code'
-- | Translate a CASM result to a Cairo result: the code is first passed
-- through 'Casm.toCairo'', then translated with 'Cairo.fromCasm' and
-- serialized together with the output size and the builtin names.
casmToCairo' :: Casm.Result -> Sem r Cairo.Result
casmToCairo' Casm.Result {..} =
  Cairo.serialize _resultOutputSize (map Casm.builtinName _resultBuiltins)
    . Cairo.fromCasm
    <$> Casm.toCairo' _resultCode
-- | Translate a JuvixReg table to a Cairo result by chaining 'regToCasm' and
-- 'casmToCairo'.
regToCairo :: (Member (Reader EntryPoint) r) => Reg.InfoTable -> Sem r Cairo.Result
regToCairo tab = regToCasm tab >>= casmToCairo
regToCairo' :: (Member (Reader Reg.Options) r) => Reg.InfoTable -> Sem r Cairo.Result
regToCairo' = regToCasm' >=> casmToCairo
regToCairo' = regToCasm' >=> casmToCairo'
-- | Translate a JuvixTree table to an Anoma result by chaining
-- 'Tree.toNockma' and 'NockmaTree.fromTreeTable'.
treeToAnoma' :: (Members '[Error JuvixError, Reader NockmaTree.CompilerOptions] r) => Tree.InfoTable -> Sem r NockmaTree.AnomaResult
treeToAnoma' tab = Tree.toNockma tab >>= NockmaTree.fromTreeTable

View File

@ -9,6 +9,7 @@ import Juvix.Compiler.Casm.Extra.InputInfo
import Juvix.Compiler.Casm.Interpreter
import Juvix.Compiler.Casm.Translation.FromSource
import Juvix.Compiler.Casm.Validate
import Juvix.Compiler.Tree.Options qualified as Casm
import Juvix.Data.Field
import Juvix.Data.PPOutput
import Juvix.Parser.Error
@ -27,14 +28,15 @@ casmRunVM labi instrs blts outputSize inputFile expectedFile step = do
step "Serialize to Cairo bytecode"
let res =
run $
casmToCairo
( Casm.Result
{ _resultLabelInfo = labi,
_resultCode = instrs,
_resultBuiltins = blts,
_resultOutputSize = outputSize
}
)
runReader Casm.defaultOptions $
casmToCairo'
( Casm.Result
{ _resultLabelInfo = labi,
_resultCode = instrs,
_resultBuiltins = blts,
_resultOutputSize = outputSize
}
)
outputFile = dirPath <//> $(mkRelFile "out.json")
encodeFile (toFilePath outputFile) res
step "Run Cairo VM"

View File

@ -166,5 +166,13 @@ tests =
$(mkRelDir ".")
$(mkRelFile "test016.casm")
$(mkRelFile "out/test016.out")
(Just $(mkRelFile "in/test016.json"))
(Just $(mkRelFile "in/test016.json")),
PosTest
"Test017: Peephole optimization"
True
True
$(mkRelDir ".")
$(mkRelFile "test017.casm")
$(mkRelFile "out/test017.out")
Nothing
]

View File

@ -0,0 +1 @@
7

View File

@ -0,0 +1,29 @@
-- peephole optimization
-- Each fragment below is shaped to match one rewrite rule of the peephole
-- optimizer. NOTE(review): the expected output registered for this test
-- appears to be 7 -- confirm against out/test017.out.
call main
jmp end
main:
-- a jump to the immediately following label
jmp lab_1
lab_1:
-- `nop` instructions
nop
nop
-- a jump separated from its target label only by `nop`s
jmp lab_2
nop
nop
lab_2:
-- `call rel 3; ret` followed (once the `nop` is removed) by a jump to a label
call rel 3
ret
nop
jmp lab_3
lab_4:
[ap] = 7; ap++
-- `call rel 3; ret` followed (once the `nop` is removed) by an assignment
-- from `[fp - 3]` and a `ret`
call rel 3
ret
nop
[ap] = [fp - 3]; ap++
ret
lab_3:
jmp lab_4
end: