diff --git a/app/Commands/Dev/Casm/Read.hs b/app/Commands/Dev/Casm/Read.hs
index cc0008b99..6beac820d 100644
--- a/app/Commands/Dev/Casm/Read.hs
+++ b/app/Commands/Dev/Casm/Read.hs
@@ -2,7 +2,11 @@ module Commands.Dev.Casm.Read where
 
 import Commands.Base
 import Commands.Dev.Casm.Read.Options
-import Juvix.Compiler.Casm.Pretty qualified as Casm
+import Juvix.Compiler.Casm.Data.InputInfo qualified as Casm
+import Juvix.Compiler.Casm.Extra.LabelInfo qualified as Casm
+import Juvix.Compiler.Casm.Interpreter qualified as Casm
+import Juvix.Compiler.Casm.Pretty qualified as Casm.Pretty
+import Juvix.Compiler.Casm.Transformation qualified as Casm
 import Juvix.Compiler.Casm.Translation.FromSource qualified as Casm
 import Juvix.Compiler.Casm.Validate qualified as Casm
 
@@ -15,7 +19,34 @@ runCommand opts = do
     Right (labi, code) ->
       case Casm.validate labi code of
         Left err -> exitJuvixError (JuvixError err)
-        Right () -> renderStdOut (Casm.ppProgram code)
+        Right () -> do
+          -- Apply the transformations requested on the command line.
+          -- NOTE(review): `Casm.applyTransformations` currently requires
+          -- neither `Error` nor `Reader`, so `Left` cannot occur yet.
+          r <-
+            runError @JuvixError
+              . runReader Casm.defaultOptions
+              $ Casm.applyTransformations (project opts ^. casmReadTransformations) code
+          case r of
+            -- `err` is already a `JuvixError`: no need to wrap it again
+            Left err -> exitJuvixError err
+            Right code' -> do
+              unless (project opts ^. casmReadNoPrint) $
+                renderStdOut (Casm.Pretty.ppProgram code')
+              doRun code'
   where
     file :: AppPath File
     file = opts ^. casmReadInputFile
+
+    -- Run the transformed code and print the result, but only if the `--run`
+    -- flag was given. The label info is recomputed from the transformed code.
+    doRun :: Casm.Code -> Sem r ()
+    doRun code'
+      | project opts ^. casmReadRun = do
+          putStrLn "--------------------------------"
+          putStrLn "| Run |"
+          putStrLn "--------------------------------"
+          let labi = Casm.computeLabelInfo code'
+              inputInfo = Casm.InputInfo mempty
+          print (Casm.runCode inputInfo labi code')
+      | otherwise = return ()
diff --git a/app/Commands/Dev/Casm/Read/Options.hs b/app/Commands/Dev/Casm/Read/Options.hs
index 311df1b45..797c09af3 100644
--- a/app/Commands/Dev/Casm/Read/Options.hs
+++ b/app/Commands/Dev/Casm/Read/Options.hs
@@ -1,9 +1,13 @@
 module Commands.Dev.Casm.Read.Options where
 
 import CommonOptions
+import Juvix.Compiler.Casm.Data.TransformationId
 
-newtype CasmReadOptions = CasmReadOptions
-  { _casmReadInputFile :: AppPath File
+data CasmReadOptions = CasmReadOptions
+  { _casmReadTransformations :: [TransformationId],
+    _casmReadRun :: Bool,
+    _casmReadNoPrint :: Bool,
+    _casmReadInputFile :: AppPath File
   }
   deriving stock (Data)
 
@@ -11,5 +15,8 @@ makeLenses ''CasmReadOptions
 
 parseCasmReadOptions :: Parser CasmReadOptions
 parseCasmReadOptions = do
+  _casmReadNoPrint <- optReadNoPrint
+  _casmReadRun <- optReadRun
+  _casmReadTransformations <- optCasmTransformationIds
   _casmReadInputFile <- parseInputFile FileExtCasm
   pure CasmReadOptions {..}
diff --git a/app/Commands/Dev/Reg/Read/Options.hs b/app/Commands/Dev/Reg/Read/Options.hs
index a19789fe0..7e0d5c450 100644
--- a/app/Commands/Dev/Reg/Read/Options.hs
+++ b/app/Commands/Dev/Reg/Read/Options.hs
@@ -15,16 +15,8 @@ makeLenses ''RegReadOptions
 
 parseRegReadOptions :: Parser RegReadOptions
 parseRegReadOptions = do
-  _regReadNoPrint <-
-    switch
-      ( long "no-print"
-          <> help "Do not print the transformed code"
-      )
-  _regReadRun <-
-    switch
-      ( long "run"
-          <> help "Run the code after the transformation"
-      )
+  _regReadNoPrint <- optReadNoPrint
+  _regReadRun <- optReadRun
   _regReadTransformations <- optRegTransformationIds
   _regReadInputFile <- parseInputFile FileExtJuvixReg
   pure RegReadOptions {..}
diff --git a/app/CommonOptions.hs b/app/CommonOptions.hs
index b3d53216a..c2fba8373 100644
--- a/app/CommonOptions.hs
+++ b/app/CommonOptions.hs
@@ -10,6 +10,7 @@ where
 import Control.Exception qualified as GHC
 import Data.List.NonEmpty qualified as NonEmpty
 import GHC.Conc
+import Juvix.Compiler.Casm.Data.TransformationId.Parser qualified as Casm
 import Juvix.Compiler.Concrete.Translation.ImportScanner
 import Juvix.Compiler.Core.Data.TransformationId.Parser qualified as Core
 import Juvix.Compiler.Pipeline.EntryPoint
@@ -282,6 +283,22 @@ optNoDisambiguate =
         <> help "Don't disambiguate the names of bound variables"
     )
 
+-- | The @--run@ switch shared by the CASM and JuvixReg @read@ commands: run
+-- the code after the transformations have been applied.
+optReadRun :: Parser Bool
+optReadRun =
+  switch
+    ( long "run"
+        <> help "Run the code after the transformation"
+    )
+
+-- | The @--no-print@ switch shared by the CASM and JuvixReg @read@ commands:
+-- do not print the transformed code.
+optReadNoPrint :: Parser Bool
+optReadNoPrint =
+  switch
+    ( long "no-print"
+        <> help "Do not print the transformed code"
+    )
+
 optTransformationIds :: forall a. (Text -> Either Text [a]) -> (String -> [String]) -> Parser [a]
 optTransformationIds parseIds completions =
   option
@@ -317,6 +334,11 @@ optTreeTransformationIds = optTransformationIds Tree.parseTransformations Tree.c
 optRegTransformationIds :: Parser [Reg.TransformationId]
 optRegTransformationIds = optTransformationIds Reg.parseTransformations Reg.completionsString
 
+-- | Parser for the list of CASM transformations, built with
+-- 'optTransformationIds' like its Tree and Reg counterparts.
+optCasmTransformationIds :: Parser [Casm.TransformationId]
+optCasmTransformationIds = optTransformationIds Casm.parseTransformations Casm.completionsString
+
 class EntryPointOptions a where
   applyOptions :: a -> EntryPoint -> EntryPoint
 
diff --git a/runtime/casm/stdlib.casm b/runtime/casm/stdlib.casm
index bcc37476f..1e8730749 100644
--- a/runtime/casm/stdlib.casm
+++ b/runtime/casm/stdlib.casm
@@ -114,6 +114,9 @@ juvix_ec_op:
 -- [fp - 3]: closure
 -- [fp - 4]: n = the number of arguments to extend with
 -- [fp - 4 - k]: argument n - k - 1 (reverse order!) (k is 0-based)
+-- On return:
+-- [ap - 1]: new closure
+-- This procedure doesn't accept or return the builtins pointer.
 juvix_extend_closure:
   -- copy stored args reversing them;
   -- to copy the stored args to the new closure
diff --git a/src/Juvix/Compiler/Casm/Data/TransformationId.hs b/src/Juvix/Compiler/Casm/Data/TransformationId.hs
new file mode 100644
index 000000000..0c275c251
--- /dev/null
+++ b/src/Juvix/Compiler/Casm/Data/TransformationId.hs
@@ -0,0 +1,37 @@
+module Juvix.Compiler.Casm.Data.TransformationId where
+
+import Juvix.Compiler.Casm.Data.TransformationId.Strings
+import Juvix.Compiler.Core.Data.TransformationId.Base
+import Juvix.Prelude
+
+-- | Identifiers of the transformations that can be applied to CASM code.
+data TransformationId
+  = IdentityTrans
+  | Peephole
+  deriving stock (Data, Bounded, Enum, Show)
+
+-- | Identifiers of named sequences of transformations.
+data PipelineId
+  = PipelineCairo
+  deriving stock (Data, Bounded, Enum)
+
+type TransformationLikeId = TransformationLikeId' TransformationId PipelineId
+
+-- | The transformations of the 'PipelineCairo' pipeline.
+toCairoTransformations :: [TransformationId]
+toCairoTransformations = [Peephole]
+
+instance TransformationId' TransformationId where
+  transformationText :: TransformationId -> Text
+  transformationText = \case
+    IdentityTrans -> strIdentity
+    Peephole -> strPeephole
+
+instance PipelineId' TransformationId PipelineId where
+  pipelineText :: PipelineId -> Text
+  pipelineText = \case
+    PipelineCairo -> strCairoPipeline
+
+  pipeline :: PipelineId -> [TransformationId]
+  pipeline = \case
+    PipelineCairo -> toCairoTransformations
diff --git a/src/Juvix/Compiler/Casm/Data/TransformationId/Parser.hs b/src/Juvix/Compiler/Casm/Data/TransformationId/Parser.hs
new file mode 100644
index 000000000..00e29f96b
--- /dev/null
+++ b/src/Juvix/Compiler/Casm/Data/TransformationId/Parser.hs
@@ -0,0 +1,14 @@
+module Juvix.Compiler.Casm.Data.TransformationId.Parser (parseTransformations, TransformationId (..), completions, completionsString) where
+
+import Juvix.Compiler.Casm.Data.TransformationId
+import Juvix.Compiler.Core.Data.TransformationId.Parser.Base
+import Juvix.Prelude
+
+parseTransformations :: Text -> Either Text [TransformationId]
+parseTransformations = parseTransformations' @TransformationId @PipelineId
+
+completionsString :: String -> [String]
+completionsString = completionsString' @TransformationId @PipelineId
+
+completions :: Text -> [Text]
+completions = completions' @TransformationId @PipelineId
diff --git a/src/Juvix/Compiler/Casm/Data/TransformationId/Strings.hs b/src/Juvix/Compiler/Casm/Data/TransformationId/Strings.hs
new file mode 100644
index 000000000..a278b5d90
--- /dev/null
+++ b/src/Juvix/Compiler/Casm/Data/TransformationId/Strings.hs
@@ -0,0 +1,12 @@
+module Juvix.Compiler.Casm.Data.TransformationId.Strings where
+
+import Juvix.Prelude
+
+strCairoPipeline :: Text
+strCairoPipeline = "pipeline-cairo"
+
+strIdentity :: Text
+strIdentity = "identity"
+
+strPeephole :: Text
+strPeephole = "peephole"
diff --git a/src/Juvix/Compiler/Casm/Pipeline.hs b/src/Juvix/Compiler/Casm/Pipeline.hs
new file mode 100644
index 000000000..3772ed5a4
--- /dev/null
+++ b/src/Juvix/Compiler/Casm/Pipeline.hs
@@ -0,0 +1,18 @@
+module Juvix.Compiler.Casm.Pipeline
+  ( module Juvix.Compiler.Casm.Pipeline,
+    Options,
+    Code,
+  )
+where
+
+import Juvix.Compiler.Casm.Transformation
+import Juvix.Compiler.Pipeline.EntryPoint (EntryPoint)
+
+-- | Perform transformations on CASM necessary before the translation to Cairo
+-- bytecode
+toCairo' :: Code -> Sem r Code
+toCairo' = applyTransformations toCairoTransformations
+
+-- | Variant of `toCairo'` with the options derived from the 'EntryPoint'
+toCairo :: (Member (Reader EntryPoint) r) => Code -> Sem r Code
+toCairo = mapReader fromEntryPoint . toCairo'
diff --git a/src/Juvix/Compiler/Casm/Transformation.hs b/src/Juvix/Compiler/Casm/Transformation.hs
new file mode 100644
index 000000000..285855ff4
--- /dev/null
+++ b/src/Juvix/Compiler/Casm/Transformation.hs
@@ -0,0 +1,20 @@
+module Juvix.Compiler.Casm.Transformation
+  ( module Juvix.Compiler.Casm.Transformation.Base,
+    module Juvix.Compiler.Casm.Transformation,
+    module Juvix.Compiler.Casm.Data.TransformationId,
+  )
+where
+
+import Juvix.Compiler.Casm.Data.TransformationId
+import Juvix.Compiler.Casm.Transformation.Base
+import Juvix.Compiler.Casm.Transformation.Optimize.Peephole
+
+-- | Apply the given transformations to the code, one after another, in the
+-- order in which they are listed.
+applyTransformations :: forall r. [TransformationId] -> Code -> Sem r Code
+applyTransformations ts code = foldM (flip appTrans) code ts
+  where
+    appTrans :: TransformationId -> Code -> Sem r Code
+    appTrans = \case
+      IdentityTrans -> return
+      Peephole -> return . peephole
diff --git a/src/Juvix/Compiler/Casm/Transformation/Base.hs b/src/Juvix/Compiler/Casm/Transformation/Base.hs
new file mode 100644
index 000000000..fdcbb1b9a
--- /dev/null
+++ b/src/Juvix/Compiler/Casm/Transformation/Base.hs
@@ -0,0 +1,20 @@
+module Juvix.Compiler.Casm.Transformation.Base
+  ( module Juvix.Compiler.Casm.Transformation.Base,
+    module Juvix.Compiler.Casm.Language,
+    module Juvix.Compiler.Tree.Options,
+  )
+where
+
+import Juvix.Compiler.Casm.Language
+import Juvix.Compiler.Tree.Options
+
+-- | Apply @f@ to every suffix of the instruction list, from the innermost
+-- (empty) suffix outwards; each application sees the already transformed
+-- remainder of the list.
+mapT :: ([Instruction] -> [Instruction]) -> [Instruction] -> [Instruction]
+mapT f = go
+  where
+    go :: [Instruction] -> [Instruction]
+    go = \case
+      i : is -> f (i : go is)
+      [] -> f []
diff --git a/src/Juvix/Compiler/Casm/Transformation/Optimize/Peephole.hs b/src/Juvix/Compiler/Casm/Transformation/Optimize/Peephole.hs
new file mode 100644
index 000000000..d934de571
--- /dev/null
+++ b/src/Juvix/Compiler/Casm/Transformation/Optimize/Peephole.hs
@@ -0,0 +1,103 @@
+module Juvix.Compiler.Casm.Transformation.Optimize.Peephole where
+
+import Juvix.Compiler.Casm.Extra.Base
+import Juvix.Compiler.Casm.Language
+import Juvix.Compiler.Casm.Transformation.Base
+
+-- | Peephole optimizer for CASM code. The rewrite rules in @go@ are tried,
+-- via 'mapT', at every suffix of the instruction list, starting from the
+-- end of the list.
+--
+-- NOTE(review): the rules delete instructions, so a relative target given
+-- as an immediate (like the @call rel 3@ matched below) stays correct only
+-- if no deleted instruction lies between the instruction and its target --
+-- presumably the code generator guarantees this; confirm.
+peephole :: [Instruction] -> [Instruction]
+peephole = mapT go
+  where
+    go :: [Instruction] -> [Instruction]
+    go = \case
+      -- drop no-ops
+      Nop : is -> is
+      -- drop a jump (without `ap++`) to the label that directly follows it
+      Jump InstrJump {..} : lab@(Label LabelRef {..}) : is
+        | not _instrJumpIncAp,
+          Val (Lab (LabelRef sym _)) <- _instrJumpTarget,
+          sym == _labelRefSymbol ->
+            lab : is
+      -- `call rel 3; ret; [ap] = [fp + k]; ap++; ret` with `k <= -3`:
+      -- replaced by `mkAssignAp` of `[ap + k + 2]` (see `getAssignApFp`)
+      Call InstrCall {..} : Return : Assign a1 : Return : is
+        | _instrCallRel,
+          Imm 3 <- _instrCallTarget,
+          Just k1 <- getAssignApFp a1 ->
+            fixAssignAp $
+              mkAssignAp (Val (Ref (MemRef Ap k1)))
+                : Return
+                : is
+      -- the same with two copies; the second offset is decreased by one
+      -- more, presumably because the first copy already incremented `ap`
+      Call InstrCall {..} : Return : Assign a1 : Assign a2 : Return : is
+        | _instrCallRel,
+          Imm 3 <- _instrCallTarget,
+          Just k1 <- getAssignApFp a1,
+          Just k2 <- getAssignApFp a2 ->
+            fixAssignAp $
+              mkAssignAp (Val (Ref (MemRef Ap k1)))
+                : mkAssignAp (Val (Ref (MemRef Ap (k2 - 1))))
+                : Return
+                : is
+      -- `call rel 3; ret; jmp L` (no `ap++`): becomes `call L; ret`
+      Call InstrCall {..} : Return : Jump InstrJump {..} : is
+        | _instrCallRel,
+          Imm 3 <- _instrCallTarget,
+          Val tgt@(Lab {}) <- _instrJumpTarget,
+          not _instrJumpIncAp ->
+            let call =
+                  InstrCall
+                    { _instrCallTarget = tgt,
+                      _instrCallRel = _instrJumpRel
+                    }
+             in Call call : Return : is
+      is -> is
+
+    -- Post-process the result of the two copy rules above when the copies
+    -- directly precede a `ret`.
+    fixAssignAp :: [Instruction] -> [Instruction]
+    fixAssignAp = \case
+      -- `[ap] = [ap - 1]; ap++; ret`: drop the copy
+      Assign a : Return : is
+        | Just (-1) <- getAssignAp Ap a ->
+            Return : is
+      -- `[ap] = [ap - 2]; ap++` twice: drop both copies
+      Assign a1 : Assign a2 : Return : is
+        | Just (-2) <- getAssignAp Ap a1,
+          Just (-2) <- getAssignAp Ap a2 ->
+            Return : is
+      -- `[ap] = [ap - 1]; ap++; [ap] = [ap - 3]; ap++`: both are replaced
+      -- by a single `mkAssignAp` of `[ap - 2]`
+      Assign a1 : Assign a2 : Return : is
+        | Just (-1) <- getAssignAp Ap a1,
+          Just (-3) <- getAssignAp Ap a2 ->
+            mkAssignAp (Val (Ref (MemRef Ap (-2)))) : Return : is
+      is -> is
+
+    -- `Just k` if the instruction is `[ap] = [reg + k]; ap++`
+    getAssignAp :: Reg -> InstrAssign -> Maybe Offset
+    getAssignAp reg InstrAssign {..}
+      | MemRef Ap 0 <- _instrAssignResult,
+        Val (Ref (MemRef r k)) <- _instrAssignValue,
+        r == reg,
+        _instrAssignIncAp =
+          Just k
+      | otherwise =
+          Nothing
+
+    -- `Just (k + 2)` if the instruction is `[ap] = [fp + k]; ap++` with
+    -- `k <= -3`. The shift by 2 presumably accounts for the two cells
+    -- pushed by the removed `call` (TODO confirm).
+    getAssignApFp :: InstrAssign -> Maybe Offset
+    getAssignApFp instr = case getAssignAp Fp instr of
+      Just k
+        | k <= -3 -> Just (k + 2)
+      _ -> Nothing
diff --git a/src/Juvix/Compiler/Casm/Translation/FromReg.hs b/src/Juvix/Compiler/Casm/Translation/FromReg.hs
index b0876c002..9e3f3b288 100644
--- a/src/Juvix/Compiler/Casm/Translation/FromReg.hs
+++ b/src/Juvix/Compiler/Casm/Translation/FromReg.hs
@@ -164,16 +164,18 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
 
     -- To ensure that memory is accessed sequentially at all times, we divide
     -- instructions into basic blocks. Within each basic block, the `ap` offset
-    -- is known at each instruction, which allows to statically associate `fp`
-    -- offsets to variables while still generating only sequential assignments
-    -- to `[ap]` with increasing `ap`. When the `ap` offset can no longer be
-    -- statically determined for new variables (e.g. due to an intervening
-    -- recursive call), we switch to the next basic block by "calling" it with
-    -- the `call` instruction (see `goCallBlock`). The arguments of the basic
-    -- block call are the variables live at the beginning of the block. Note
-    -- that the `fp` offsets of "old" variables are still statically determined
-    -- even after the current `ap` offset becomes unknown -- the arbitrary
-    -- increase of `ap` does not influence the previous variable associations.
+    -- (i.e. how much `ap` increased since the start of the basic block) is
+    -- known at each instruction, which allows us to statically associate `fp`
+    -- offsets (i.e. offsets relative to `fp`) to variables while still
+    -- generating only sequential assignments to `[ap]` with increasing `ap`.
+    -- When the `ap` offset can no longer be statically determined for new
+    -- variables (e.g. due to an intervening recursive call), we switch to the
+    -- next basic block by "calling" it with the `call` instruction (see
+    -- `goCallBlock`). The arguments of the basic block call are the variables
+    -- live at the beginning of the block. Note that the `fp` offsets of "old"
+    -- variables are still statically determined even after the current `ap`
+    -- offset becomes unknown -- the arbitrary increase of `ap` does not
+    -- influence the previous variable associations.
     goBlock :: forall r. (Members '[LabelInfoBuilder, CasmBuilder, Output Instruction] r) => StdlibBuiltins -> LabelRef -> HashSet Reg.VarRef -> Maybe Reg.VarRef -> Reg.Block -> Sem r ()
     goBlock blts failLab liveVars0 mout Reg.Block {..} = do
       mapM_ goInstr _blockBody
@@ -645,7 +647,10 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
       ap0 <- getAP
       vars <- getVars
       bltOff <- getBuiltinOffset
-      mapM_ (goCaseBranch ap0 vars bltOff symMap labEnd) _instrCaseBranches
+      -- reversing `_instrCaseBranches` typically results in better
+      -- opportunities for peephole optimization (the last jump to branch
+      -- may be removed by the peephole optimizer)
+      mapM_ (goCaseBranch ap0 vars bltOff symMap labEnd) (reverse _instrCaseBranches)
       mapM_ (goDefaultLabel symMap) defaultTags
       whenJust _instrCaseDefault $
         goLocalBlock ap0 vars bltOff liveVars _instrCaseOutVar
diff --git a/src/Juvix/Compiler/Pipeline.hs b/src/Juvix/Compiler/Pipeline.hs
index 58cd880cb..4ea0ab455 100644
--- a/src/Juvix/Compiler/Pipeline.hs
+++ b/src/Juvix/Compiler/Pipeline.hs
@@ -23,6 +23,7 @@ import Juvix.Compiler.Backend.Rust.Translation.FromReg qualified as Rust
 import Juvix.Compiler.Backend.VampIR.Translation qualified as VampIR
 import Juvix.Compiler.Casm.Data.Builtins qualified as Casm
 import Juvix.Compiler.Casm.Data.Result qualified as Casm
+import Juvix.Compiler.Casm.Pipeline qualified as Casm
 import Juvix.Compiler.Casm.Translation.FromReg qualified as Casm
 import Juvix.Compiler.Concrete.Data.Highlight.Input
 import Juvix.Compiler.Concrete.Language
@@ -364,17 +365,25 @@ regToCasm = Reg.toCasm >=> return . Casm.fromReg
 regToCasm' :: (Member (Reader Reg.Options) r) => Reg.InfoTable -> Sem r Casm.Result
 regToCasm' = Reg.toCasm' >=> return . Casm.fromReg
 
-casmToCairo :: Casm.Result -> Sem r Cairo.Result
-casmToCairo Casm.Result {..} =
+casmToCairo :: (Member (Reader EntryPoint) r) => Casm.Result -> Sem r Cairo.Result
+casmToCairo Casm.Result {..} = do
+  code' <- Casm.toCairo _resultCode
   return
     . Cairo.serialize _resultOutputSize (map Casm.builtinName _resultBuiltins)
-    $ Cairo.fromCasm _resultCode
+    $ Cairo.fromCasm code'
+
+casmToCairo' :: Casm.Result -> Sem r Cairo.Result
+casmToCairo' Casm.Result {..} = do
+  code' <- Casm.toCairo' _resultCode
+  return
+    . Cairo.serialize _resultOutputSize (map Casm.builtinName _resultBuiltins)
+    $ Cairo.fromCasm code'
 
 regToCairo :: (Member (Reader EntryPoint) r) => Reg.InfoTable -> Sem r Cairo.Result
 regToCairo = regToCasm >=> casmToCairo
 
 regToCairo' :: (Member (Reader Reg.Options) r) => Reg.InfoTable -> Sem r Cairo.Result
-regToCairo' = regToCasm' >=> casmToCairo
+regToCairo' = regToCasm' >=> casmToCairo'
 
 treeToAnoma' :: (Members '[Error JuvixError, Reader NockmaTree.CompilerOptions] r) => Tree.InfoTable -> Sem r NockmaTree.AnomaResult
 treeToAnoma' = Tree.toNockma >=> NockmaTree.fromTreeTable
diff --git a/test/Casm/Run/Base.hs b/test/Casm/Run/Base.hs
index c340841d9..365d5f38c 100644
--- a/test/Casm/Run/Base.hs
+++ b/test/Casm/Run/Base.hs
@@ -27,7 +27,7 @@ casmRunVM labi instrs blts outputSize inputFile expectedFile step = do
         step "Serialize to Cairo bytecode"
         let res =
               run $
-                casmToCairo
+                casmToCairo'
                   ( Casm.Result
                       { _resultLabelInfo = labi,
                         _resultCode = instrs,
diff --git a/test/Casm/Run/Positive.hs b/test/Casm/Run/Positive.hs
index 8a9cfcd47..1021a5ef3 100644
--- a/test/Casm/Run/Positive.hs
+++ b/test/Casm/Run/Positive.hs
@@ -166,5 +166,13 @@ tests =
       $(mkRelDir ".")
       $(mkRelFile "test016.casm")
       $(mkRelFile "out/test016.out")
-      (Just $(mkRelFile "in/test016.json"))
+      (Just $(mkRelFile "in/test016.json")),
+    PosTest
+      "Test017: Peephole optimization"
+      True
+      True
+      $(mkRelDir ".")
+      $(mkRelFile "test017.casm")
+      $(mkRelFile "out/test017.out")
+      Nothing
   ]
diff --git a/tests/Casm/positive/out/test017.out b/tests/Casm/positive/out/test017.out
new file mode 100644
index 000000000..7f8f011eb
--- /dev/null
+++ b/tests/Casm/positive/out/test017.out
@@ -0,0 +1 @@
+7
diff --git a/tests/Casm/positive/test017.casm b/tests/Casm/positive/test017.casm
new file mode 100644
index 000000000..5245d7d73
--- /dev/null
+++ b/tests/Casm/positive/test017.casm
@@ -0,0 +1,29 @@
+-- peephole optimization
+
+call main
+jmp end
+
+main:
+jmp lab_1
+lab_1:
+nop
+nop
+jmp lab_2
+nop
+nop
+lab_2:
+call rel 3
+ret
+nop
+jmp lab_3
+lab_4:
+[ap] = 7; ap++
+call rel 3
+ret
+nop
+[ap] = [fp - 3]; ap++
+ret
+lab_3:
+jmp lab_4
+
+end: