diff --git a/app/Commands/Dev/Casm/Compile.hs b/app/Commands/Dev/Casm/Compile.hs index b430bc6cd..ab992be24 100644 --- a/app/Commands/Dev/Casm/Compile.hs +++ b/app/Commands/Dev/Casm/Compile.hs @@ -28,7 +28,7 @@ runCommand opts = do runReader entryPoint . runError @JuvixError . casmToCairo - $ Casm.Result labi code + $ Casm.Result labi code [] res <- getRight r liftIO $ JSON.encodeFile (toFilePath cairoFile) res where diff --git a/runtime/src/casm/stdlib.casm b/runtime/src/casm/stdlib.casm index 511c5ea3e..a6e46a45b 100644 --- a/runtime/src/casm/stdlib.casm +++ b/runtime/src/casm/stdlib.casm @@ -20,7 +20,7 @@ juvix_get_ap_reg: -- [fp - 3]: closure -- [fp - 4]: n = the number of arguments to extend with --- [fp - 4 - k]: argument n - k - 1 (reverse order!) +-- [fp - 4 - k]: argument n - k - 1 (reverse order!) (k is 0-based) juvix_extend_closure: -- copy stored args reversing them; -- to copy the stored args to the new closure @@ -95,10 +95,14 @@ juvix_extend_closure: [ap] = [fp + 15]; ap++ ret --- [fp - 3]: closure; [fp - 3 - k]: argument k to closure call +-- [fp - 3]: closure; +-- [fp - 4 - k]: argument k to closure call (0-based) +-- [fp - 4 - n]: builtin pointer, where n = number of supplied args juvix_call_closure: -- jmp rel (9 - argsnum) jmp rel [[fp - 3] + 2] + -- builtin ptr + args + [ap] = [fp - 12]; ap++ [ap] = [fp - 11]; ap++ [ap] = [fp - 10]; ap++ [ap] = [fp - 9]; ap++ diff --git a/scripts/run_cairo_vm.sh b/scripts/run_cairo_vm.sh index 5b2dd09b2..a04213421 100755 --- a/scripts/run_cairo_vm.sh +++ b/scripts/run_cairo_vm.sh @@ -2,4 +2,4 @@ BASE=`basename "$1" .json` -juvix-cairo-vm "$@" --print_output --proof_mode --trace_file ${BASE}.trace --air_public_input=${BASE}_public_input.json --air_private_input=${BASE}_private_input.json --memory_file=${BASE}_memory.mem --layout=small +juvix-cairo-vm "$@" --print_output --proof_mode --trace_file ${BASE}.trace --air_public_input=${BASE}_public_input.json --air_private_input=${BASE}_private_input.json --memory_file=${BASE}_memory.mem --layout=all_cairo diff --git a/src/Juvix/Compiler/Backend/Cairo/Extra/Serialization.hs b/src/Juvix/Compiler/Backend/Cairo/Extra/Serialization.hs index 9ec5fd6bd..73f3d4333 100644 --- a/src/Juvix/Compiler/Backend/Cairo/Extra/Serialization.hs +++ b/src/Juvix/Compiler/Backend/Cairo/Extra/Serialization.hs @@ -5,26 +5,29 @@ import Juvix.Compiler.Backend.Cairo.Data.Result import Juvix.Compiler.Backend.Cairo.Language import Numeric -serialize :: [Element] -> Result -serialize elems = +serialize :: [Text] -> [Element] -> Result +serialize builtins elems = Result { _resultData = - initializeOutput + initializeBuiltins ++ map toHexText (serialize' elems) - ++ finalizeOutput + ++ finalizeBuiltins ++ finalizeJump, _resultStart = 0, - _resultEnd = length initializeOutput + length elems + length finalizeOutput, + _resultEnd = length initializeBuiltins + length elems + length finalizeBuiltins, _resultMain = 0, _resultHints = hints, - _resultBuiltins = ["output"] + _resultBuiltins = "output" : builtins } where + builtinsNum :: Natural + builtinsNum = fromIntegral (length builtins) + hints :: [(Int, Text)] hints = catMaybes $ zipWith mkHint elems [0 ..] pcShift :: Int - pcShift = length initializeOutput + pcShift = length initializeBuiltins mkHint :: Element -> Int -> Maybe (Int, Text) mkHint el pc = case el of @@ -34,21 +37,30 @@ serialize elems = toHexText :: Natural -> Text toHexText n = "0x" <> fromString (showHex n "") - initializeOutput :: [Text] - initializeOutput = + initializeBuiltins :: [Text] + initializeBuiltins = + -- ap += allBuiltinsNum [ "0x40480017fff7fff", - "0x1" + toHexText (builtinsNum + 1) ] - finalizeOutput :: [Text] - finalizeOutput = + finalizeBuiltins :: [Text] + finalizeBuiltins = + -- [[fp]] = [ap - 1] -- [output_ptr] = [ap - 1] + -- [ap] = [fp] + 1; ap++ -- output_ptr [ "0x4002800080007fff", "0x4826800180008000", "0x1" ] + ++ + -- [ap] = [ap - builtinsNum - 2]; ap++ + replicate + builtinsNum + (toHexText (0x48107ffe7fff8000 - shift builtinsNum 32)) finalizeJump :: [Text] finalizeJump = + -- jmp rel 0 [ "0x10780017fff7fff", "0x0" ] diff --git a/src/Juvix/Compiler/Casm/Data/Builtins.hs b/src/Juvix/Compiler/Casm/Data/Builtins.hs new file mode 100644 index 000000000..51ef3e65f --- /dev/null +++ b/src/Juvix/Compiler/Casm/Data/Builtins.hs @@ -0,0 +1,27 @@ +module Juvix.Compiler.Casm.Data.Builtins where + +import Juvix.Extra.Strings qualified as Str +import Juvix.Prelude + +-- The order of the builtins must correspond to the "standard" builtin order in +-- the Cairo VM implementation. See: +-- https://github.com/lambdaclass/cairo-vm/blob/main/vm/src/vm/runners/cairo_runner.rs#L257 +data Builtin + = BuiltinRangeCheck + | BuiltinEcOp + | BuiltinPoseidon + deriving stock (Show, Eq, Generic, Enum, Bounded) + +instance Hashable Builtin + +builtinsNum :: Int +builtinsNum = length (allElements @Builtin) + +builtinName :: Builtin -> Text +builtinName = \case + BuiltinRangeCheck -> Str.cairoRangeCheck + BuiltinEcOp -> Str.cairoEcOp + BuiltinPoseidon -> Str.cairoPoseidon + +builtinNames :: [Text] +builtinNames = map builtinName allElements diff --git a/src/Juvix/Compiler/Casm/Data/LabelInfoBuilder.hs b/src/Juvix/Compiler/Casm/Data/LabelInfoBuilder.hs index 4a10bb326..9e8cdff41 100644 --- a/src/Juvix/Compiler/Casm/Data/LabelInfoBuilder.hs +++ b/src/Juvix/Compiler/Casm/Data/LabelInfoBuilder.hs @@ -1,4 +1,8 @@ -module Juvix.Compiler.Casm.Data.LabelInfoBuilder where +module Juvix.Compiler.Casm.Data.LabelInfoBuilder + ( module Juvix.Compiler.Casm.Data.LabelInfo, + module Juvix.Compiler.Casm.Data.LabelInfoBuilder, + ) +where import Data.HashMap.Strict qualified as HashMap import Juvix.Compiler.Casm.Data.LabelInfo diff --git a/src/Juvix/Compiler/Casm/Data/Result.hs b/src/Juvix/Compiler/Casm/Data/Result.hs index 3310e521d..4d74e569b 100644 --- a/src/Juvix/Compiler/Casm/Data/Result.hs +++ b/src/Juvix/Compiler/Casm/Data/Result.hs @@ -1,11 +1,13 @@ module Juvix.Compiler.Casm.Data.Result where +import Juvix.Compiler.Casm.Data.Builtins import Juvix.Compiler.Casm.Data.LabelInfo import Juvix.Compiler.Casm.Language data Result = Result { _resultLabelInfo :: LabelInfo, - _resultCode :: [Instruction] + _resultCode :: [Instruction], + _resultBuiltins :: [Builtin] } makeLenses ''Result diff --git a/src/Juvix/Compiler/Casm/Interpreter.hs b/src/Juvix/Compiler/Casm/Interpreter.hs index 36fb56a63..2154f567e 100644 --- a/src/Juvix/Compiler/Casm/Interpreter.hs +++ b/src/Juvix/Compiler/Casm/Interpreter.hs @@ -11,6 +11,7 @@ import Data.HashMap.Strict qualified as HashMap import Data.Vector qualified as Vec import Data.Vector.Mutable qualified as MV import GHC.IO qualified as GHC +import Juvix.Compiler.Casm.Data.Builtins import Juvix.Compiler.Casm.Data.InputInfo import Juvix.Compiler.Casm.Data.LabelInfo import Juvix.Compiler.Casm.Error @@ -39,7 +40,9 @@ hRunCode hout inputInfo (LabelInfo labelInfo) instrs0 = runST goCode goCode :: ST s FField goCode = do mem <- MV.replicate initialMemSize Nothing - go 0 0 0 mem + forM_ [0 .. builtinsNum] $ \k -> + MV.write mem k (Just (fieldFromInteger cairoFieldSize 0)) + go 0 (builtinsNum + 1) 0 mem go :: Address -> diff --git a/src/Juvix/Compiler/Casm/Translation/FromCairo.hs b/src/Juvix/Compiler/Casm/Translation/FromCairo.hs index 5375fa675..b00fdbb15 100644 --- a/src/Juvix/Compiler/Casm/Translation/FromCairo.hs +++ b/src/Juvix/Compiler/Casm/Translation/FromCairo.hs @@ -5,7 +5,12 @@ import Juvix.Compiler.Casm.Data.Result import Juvix.Compiler.Casm.Language fromCairo :: [Cairo.Element] -> Result -fromCairo elems0 = Result mempty (go 0 [] elems0) +fromCairo elems0 = + Result + { _resultLabelInfo = mempty, + _resultCode = go 0 [] elems0, + _resultBuiltins = mempty + } where errorMsg :: Address -> Text -> a errorMsg addr msg = error ("error at address " <> show addr <> ": " <> msg) diff --git a/src/Juvix/Compiler/Casm/Translation/FromReg.hs b/src/Juvix/Compiler/Casm/Translation/FromReg.hs index a1aad774c..cdfd321c0 100644 --- a/src/Juvix/Compiler/Casm/Translation/FromReg.hs +++ b/src/Juvix/Compiler/Casm/Translation/FromReg.hs @@ -3,6 +3,7 @@ module Juvix.Compiler.Casm.Translation.FromReg where import Data.HashMap.Strict qualified as HashMap import Data.HashSet qualified as HashSet import Data.Text qualified as Text +import Juvix.Compiler.Casm.Data.Builtins import Juvix.Compiler.Casm.Data.LabelInfoBuilder import Juvix.Compiler.Casm.Data.Limits import Juvix.Compiler.Casm.Data.Result @@ -17,26 +18,56 @@ import Juvix.Compiler.Tree.Evaluator.Builtins qualified as Reg import Juvix.Data.Field fromReg :: Reg.InfoTable -> Result -fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolId tab) $ do +fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolId tab) $ do + let startAddr :: Address = 2 + startSym <- freshSymbol + endSym <- freshSymbol + let startName :: Text = "__juvix_start" + startLab = LabelRef startSym (Just startName) + endName :: Text = "__juvix_end" + endLab = LabelRef endSym (Just endName) + registerLabelName startSym startName + registerLabelAddress startSym startAddr let mainSym = fromJust $ tab ^. Reg.infoMainFunction mainInfo = fromJust (HashMap.lookup mainSym (tab ^. Reg.infoFunctions)) mainName = mainInfo ^. Reg.functionName mainArgs = getInputArgs (mainInfo ^. Reg.functionArgsNum) (mainInfo ^. Reg.functionArgNames) - initialOffset = length mainArgs + 2 - (blts, binstrs) <- addStdlibBuiltins initialOffset + bnum = toOffset builtinsNum + callStartInstr = mkCallRel (Lab startLab) + initBuiltinsInstr = mkAssignAp (Binop $ BinopValue FieldAdd (MemRef Fp (-2)) (Imm 1)) + callMainInstr = mkCallRel (Lab $ LabelRef mainSym (Just mainName)) + jmpEndInstr = mkJumpRel (Val $ Lab endLab) + margs = reverse $ map (Hint . HintInput) mainArgs + -- [ap] = [[ap - 2 - k] + k]; ap++ + bltsRet = map (\k -> mkAssignAp (Load $ LoadValue (MemRef Ap (-2 - k)) k)) [0 .. bnum - 1] + resRetInstr = mkAssignAp (Val $ Ref $ MemRef Ap (-bnum - 1)) + pinstrs = + callStartInstr + : jmpEndInstr + : Label startLab + : initBuiltinsInstr + : margs + ++ callMainInstr + : bltsRet + ++ [resRetInstr, Return] + (blts, binstrs) <- addStdlibBuiltins (length pinstrs) let cinstrs = concatMap (mkFunCall . fst) $ sortOn snd $ HashMap.toList (info ^. Reg.extraInfoFUIDs) - endSym <- freshSymbol - let endName :: Text = "__juvix_end" - endLab = LabelRef endSym (Just endName) - (addr, instrs) <- second (concat . reverse) <$> foldM (goFun blts endLab) (initialOffset + length binstrs + length cinstrs, []) (tab ^. Reg.infoFunctions) - eassert (addr == length instrs + length cinstrs + length binstrs + initialOffset) + (addr, instrs) <- second (concat . reverse) <$> foldM (goFun blts endLab) (length pinstrs + length binstrs + length cinstrs, []) (tab ^. Reg.infoFunctions) + eassert (addr == length instrs + length cinstrs + length binstrs + length pinstrs) registerLabelName endSym endName registerLabelAddress endSym addr - let callInstr = mkCallRel (Lab $ LabelRef mainSym (Just mainName)) - jmpInstr = mkJumpRel (Val $ Lab endLab) - margs = reverse $ map (Hint . HintInput) mainArgs - return $ margs ++ callInstr : jmpInstr : binstrs ++ cinstrs ++ instrs ++ [Label endLab] + return $ + ( allElements, + pinstrs + ++ binstrs + ++ cinstrs + ++ instrs + ++ [Label endLab] + ) where + mkResult :: (LabelInfo, ([Builtin], Code)) -> Result + mkResult (labi, (blts, code)) = Result labi code blts + info :: Reg.ExtraInfo info = Reg.computeExtraInfo tab @@ -72,6 +103,9 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS ("fp", "__fp__") ] + argsOffset :: Int + argsOffset = 3 + goFun :: forall r. (Member LabelInfoBuilder r) => StdlibBuiltins -> LabelRef -> (Address, [[Instruction]]) -> Reg.FunctionInfo -> Sem r (Address, [[Instruction]]) goFun blts failLab (addr0, acc) funInfo = do let sym = funInfo ^. Reg.functionSymbol @@ -85,10 +119,10 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS n = funInfo ^. Reg.functionArgsNum let vars = HashMap.fromList $ - map (\k -> (Reg.VarRef Reg.VarGroupArgs k Nothing, -3 - k)) [0 .. n - 1] + map (\k -> (Reg.VarRef Reg.VarGroupArgs k Nothing, -argsOffset - k)) [0 .. n - 1] instrs <- fmap fst - . runCasmBuilder addr1 vars + . runCasmBuilder addr1 vars (-argsOffset - n) . runOutputList $ goBlock blts failLab mempty Nothing block return (addr1 + length instrs, (pre ++ instrs) : acc) @@ -119,7 +153,7 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS Nothing -> do eassert (isJust mout) eassert (HashSet.member (fromJust mout) liveVars0) - goCallBlock Nothing liveVars0 + goCallBlock False Nothing liveVars0 where output'' :: Instruction -> Sem r () output'' i = do @@ -131,14 +165,22 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS output'' i incAP apOff - goCallBlock :: Maybe Reg.VarRef -> HashSet Reg.VarRef -> Sem r () - goCallBlock outVar liveVars = do + goCallBlock :: Bool -> Maybe Reg.VarRef -> HashSet Reg.VarRef -> Sem r () + goCallBlock updatedBuiltins outVar liveVars = do let liveVars' = toList (maybe liveVars (flip HashSet.delete liveVars) outVar) n = length liveVars' + bltOff = + if + | updatedBuiltins -> + -argsOffset - n - fromEnum (isJust outVar) + | otherwise -> + -argsOffset - n vars = HashMap.fromList $ - maybe [] (\var -> [(var, -3 - n)]) outVar - ++ zipWithExact (\var k -> (var, -3 - k)) liveVars' [0 .. n - 1] + maybe [] (\var -> [(var, -argsOffset - n - if updatedBuiltins then 0 else 1)]) outVar + ++ zipWithExact (\var k -> (var, -argsOffset - k)) liveVars' [0 .. n - 1] + unless updatedBuiltins $ + goAssignApBuiltins mapM_ (mkMemRef >=> goAssignAp . Val . Ref) (reverse liveVars') output'' (mkCallRel $ Imm 3) output'' Return @@ -148,11 +190,13 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS output'' Nop setAP 0 setVars vars + setBuiltinOffset bltOff - goLocalBlock :: Int -> HashMap Reg.VarRef Int -> HashSet Reg.VarRef -> Maybe Reg.VarRef -> Reg.Block -> Sem r () - goLocalBlock ap0 vars liveVars mout' block = do + goLocalBlock :: Int -> HashMap Reg.VarRef Int -> Int -> HashSet Reg.VarRef -> Maybe Reg.VarRef -> Reg.Block -> Sem r () + goLocalBlock ap0 vars bltOff liveVars mout' block = do setAP ap0 setVars vars + setBuiltinOffset bltOff goBlock blts failLab liveVars mout' block ---------------------------------------------------------------------- @@ -179,6 +223,11 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS v <- lookupVar' vr return $ MemRef Fp (toOffset v) + mkBuiltinRef :: Sem r MemRef + mkBuiltinRef = do + off <- getBuiltinOffset + return $ MemRef Fp (toOffset off) + mkRValue :: Reg.Value -> Sem r RValue mkRValue = \case Reg.ValConst c -> return $ Val $ Imm $ mkConst c @@ -216,6 +265,9 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS goAssignApValue :: Reg.Value -> Sem r () goAssignApValue v = mkRValue v >>= goAssignAp + goAssignApBuiltins :: Sem r () + goAssignApBuiltins = mkBuiltinRef >>= goAssignAp . Val . Ref + goValue :: Reg.Value -> Sem r Value goValue = \case Reg.ValConst c -> return $ Imm $ mkConst c @@ -414,16 +466,18 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS val <- mkMemRef _instrExtendClosureValue goAssignAp (Val $ Ref val) output'' $ mkCallRel $ Lab $ LabelRef (blts ^. stdlibExtendClosure) (Just (blts ^. stdlibExtendClosureName)) - goCallBlock (Just _instrExtendClosureResult) liveVars + goCallBlock False (Just _instrExtendClosureResult) liveVars goCall' :: Reg.CallType -> [Reg.Value] -> Sem r () goCall' ct args = case ct of Reg.CallFun sym -> do + goAssignApBuiltins mapM_ goAssignApValue (reverse args) output'' $ mkCallRel $ Lab $ LabelRef sym (Just funName) where funName = quoteName (Reg.lookupFunInfo tab sym ^. Reg.functionName) Reg.CallClosure cl -> do + goAssignApBuiltins mapM_ goAssignApValue (reverse args) r <- mkMemRef cl goAssignAp (Val $ Ref r) @@ -432,7 +486,7 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS goCall :: HashSet Reg.VarRef -> Reg.InstrCall -> Sem r () goCall liveVars Reg.InstrCall {..} = do goCall' _instrCallType _instrCallArgs - goCallBlock (Just _instrCallResult) liveVars + goCallBlock True (Just _instrCallResult) liveVars -- There is no way to make "proper" tail calls in Cairo, because -- the only way to set the `fp` register is via the `call` instruction. @@ -444,6 +498,7 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS goReturn :: Reg.InstrReturn -> Sem r () goReturn Reg.InstrReturn {..} = do + goAssignApBuiltins goAssignApValue _instrReturnValue output'' Return @@ -462,14 +517,15 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS output'' $ mkJumpIf (Lab labFalse) r ap0 <- getAP vars <- getVars - goLocalBlock ap0 vars liveVars _instrBranchOutVar _instrBranchTrue + bltOff <- getBuiltinOffset + goLocalBlock ap0 vars bltOff liveVars _instrBranchOutVar _instrBranchTrue -- _instrBranchOutVar is Nothing iff the branch returns when (isJust _instrBranchOutVar) $ output'' (mkJumpRel (Val $ Lab labEnd)) addrFalse <- getPC registerLabelAddress symFalse addrFalse output'' $ Label labFalse - goLocalBlock ap0 vars liveVars _instrBranchOutVar _instrBranchFalse + goLocalBlock ap0 vars bltOff liveVars _instrBranchOutVar _instrBranchFalse addrEnd <- getPC registerLabelAddress symEnd addrEnd output'' $ Label labEnd @@ -501,10 +557,11 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS mapM_ output'' jmps' ap0 <- getAP vars <- getVars - mapM_ (goCaseBranch ap0 vars symMap labEnd) _instrCaseBranches + bltOff <- getBuiltinOffset + mapM_ (goCaseBranch ap0 vars bltOff symMap labEnd) _instrCaseBranches mapM_ (goDefaultLabel symMap) defaultTags whenJust _instrCaseDefault $ - goLocalBlock ap0 vars liveVars _instrCaseOutVar + goLocalBlock ap0 vars bltOff liveVars _instrCaseOutVar addrEnd <- getPC registerLabelAddress symEnd addrEnd output'' $ Label labEnd @@ -513,14 +570,14 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS ctrTags = HashSet.fromList $ map (^. Reg.caseBranchTag) _instrCaseBranches defaultTags = filter (not . flip HashSet.member ctrTags) tags - goCaseBranch :: Int -> HashMap Reg.VarRef Int -> HashMap Tag Symbol -> LabelRef -> Reg.CaseBranch -> Sem r () - goCaseBranch ap0 vars symMap labEnd Reg.CaseBranch {..} = do + goCaseBranch :: Int -> HashMap Reg.VarRef Int -> Int -> HashMap Tag Symbol -> LabelRef -> Reg.CaseBranch -> Sem r () + goCaseBranch ap0 vars bltOff symMap labEnd Reg.CaseBranch {..} = do let sym = fromJust $ HashMap.lookup _caseBranchTag symMap lab = LabelRef sym Nothing addr <- getPC registerLabelAddress sym addr output'' $ Label lab - goLocalBlock ap0 vars liveVars _instrCaseOutVar _caseBranchCode + goLocalBlock ap0 vars bltOff liveVars _instrCaseOutVar _caseBranchCode -- _instrCaseOutVar is Nothing iff the branch returns when (isJust _instrCaseOutVar) $ output'' (mkJumpRel (Val $ Lab labEnd)) diff --git a/src/Juvix/Compiler/Casm/Translation/FromReg/CasmBuilder.hs b/src/Juvix/Compiler/Casm/Translation/FromReg/CasmBuilder.hs index 1caec9ce4..2b7dd1d98 100644 --- a/src/Juvix/Compiler/Casm/Translation/FromReg/CasmBuilder.hs +++ b/src/Juvix/Compiler/Casm/Translation/FromReg/CasmBuilder.hs @@ -14,27 +14,31 @@ data CasmBuilder :: Effect where LookupVar :: VarRef -> CasmBuilder m (Maybe Int) GetVars :: CasmBuilder m (HashMap VarRef Int) SetVars :: HashMap VarRef Int -> CasmBuilder m () + GetBuiltinOffset :: CasmBuilder m Int + SetBuiltinOffset :: Int -> CasmBuilder m () makeSem ''CasmBuilder data BuilderState = BuilderState { _statePC :: Address, _stateAP :: Int, - _stateVarMap :: HashMap VarRef Int + _stateVarMap :: HashMap VarRef Int, + _stateBuiltinOff :: Int } makeLenses ''BuilderState -mkBuilderState :: Address -> HashMap VarRef Int -> BuilderState -mkBuilderState addr vars = +mkBuilderState :: Address -> HashMap VarRef Int -> Int -> BuilderState +mkBuilderState addr vars bltOff = BuilderState { _statePC = addr, _stateAP = 0, - _stateVarMap = vars + _stateVarMap = vars, + _stateBuiltinOff = bltOff } -runCasmBuilder :: Address -> HashMap VarRef Int -> Sem (CasmBuilder ': r) a -> Sem r a -runCasmBuilder addr vars = fmap snd . runCasmBuilder' (mkBuilderState addr vars) +runCasmBuilder :: Address -> HashMap VarRef Int -> Int -> Sem (CasmBuilder ': r) a -> Sem r a +runCasmBuilder addr vars bltOff = fmap snd . runCasmBuilder' (mkBuilderState addr vars bltOff) runCasmBuilder' :: BuilderState -> Sem (CasmBuilder ': r) a -> Sem r (BuilderState, a) runCasmBuilder' bs = reinterpret (runState bs) interp @@ -60,6 +64,10 @@ runCasmBuilder' bs = reinterpret (runState bs) interp gets (^. stateVarMap) SetVars vars -> do modify' (set stateVarMap vars) + GetBuiltinOffset -> do + gets (^. stateBuiltinOff) + SetBuiltinOffset bltOff -> do + modify' (set stateBuiltinOff bltOff) lookupVar' :: (Member CasmBuilder r) => VarRef -> Sem r Int lookupVar' = lookupVar >=> return . fromJust diff --git a/src/Juvix/Compiler/Pipeline.hs b/src/Juvix/Compiler/Pipeline.hs index 2d35aeab5..2eb4017d6 100644 --- a/src/Juvix/Compiler/Pipeline.hs +++ b/src/Juvix/Compiler/Pipeline.hs @@ -17,6 +17,7 @@ import Juvix.Compiler.Backend.C qualified as C import Juvix.Compiler.Backend.Cairo qualified as Cairo import Juvix.Compiler.Backend.Geb qualified as Geb import Juvix.Compiler.Backend.VampIR.Translation qualified as VampIR +import Juvix.Compiler.Casm.Data.Builtins qualified as Casm import Juvix.Compiler.Casm.Data.Result qualified as Casm import Juvix.Compiler.Casm.Translation.FromReg qualified as Casm import Juvix.Compiler.Concrete.Data.Highlight.Input @@ -278,7 +279,10 @@ regToCasm :: Reg.InfoTable -> Sem r Casm.Result regToCasm = Reg.toCasm >=> return . Casm.fromReg casmToCairo :: Casm.Result -> Sem r Cairo.Result -casmToCairo Casm.Result {..} = return $ Cairo.serialize $ Cairo.fromCasm _resultCode +casmToCairo Casm.Result {..} = + return + . Cairo.serialize (map Casm.builtinName _resultBuiltins) + $ Cairo.fromCasm _resultCode regToCairo :: Reg.InfoTable -> Sem r Cairo.Result regToCairo = regToCasm >=> casmToCairo diff --git a/src/Juvix/Extra/Strings.hs b/src/Juvix/Extra/Strings.hs index 4c9a2b3d7..dc76c4f90 100644 --- a/src/Juvix/Extra/Strings.hs +++ b/src/Juvix/Extra/Strings.hs @@ -991,3 +991,12 @@ functionsPlaceholder = "functionsLibrary_placeholder" theFunctionsLibrary :: (IsString s) => s theFunctionsLibrary = "the_functionsLibrary" + +cairoRangeCheck :: (IsString s) => s +cairoRangeCheck = "range_check" + +cairoPoseidon :: (IsString s) => s +cairoPoseidon = "poseidon" + +cairoEcOp :: (IsString s) => s +cairoEcOp = "ec_op" diff --git a/test/Casm/Compilation/Base.hs b/test/Casm/Compilation/Base.hs index 14f2bce97..6fc13a684 100644 --- a/test/Casm/Compilation/Base.hs +++ b/test/Casm/Compilation/Base.hs @@ -42,4 +42,4 @@ compileAssertionEntry adjustEntry root' bRunVM optLevel mainFile expectedFile st step "Pretty print" writeFileEnsureLn tmpFile (toPlainText $ ppProgram _resultCode) ) - casmRunAssertion' bRunVM _resultLabelInfo _resultCode Nothing expectedFile step + casmRunAssertion' bRunVM _resultLabelInfo _resultCode _resultBuiltins Nothing expectedFile step diff --git a/test/Casm/Run/Base.hs b/test/Casm/Run/Base.hs index 948c2556b..16f76ed10 100644 --- a/test/Casm/Run/Base.hs +++ b/test/Casm/Run/Base.hs @@ -2,6 +2,7 @@ module Casm.Run.Base where import Base import Data.Aeson +import Juvix.Compiler.Casm.Data.Builtins import Juvix.Compiler.Casm.Data.Result qualified as Casm import Juvix.Compiler.Casm.Error import Juvix.Compiler.Casm.Extra.InputInfo @@ -18,14 +19,14 @@ casmRunVM' dirPath outputFile inputFile = do let args = maybe [] (\f -> ["--program_input", toFilePath f]) inputFile R.readProcessCwd (toFilePath dirPath) "run_cairo_vm.sh" (toFilePath outputFile : args) "" -casmRunVM :: LabelInfo -> Code -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion -casmRunVM labi instrs inputFile expectedFile step = do +casmRunVM :: LabelInfo -> Code -> [Builtin] -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion +casmRunVM labi instrs blts inputFile expectedFile step = do step "Check run_cairo_vm.sh is on path" assertCmdExists $(mkRelFile "run_cairo_vm.sh") withTempDir' ( \dirPath -> do step "Serialize to Cairo bytecode" - let res = run $ casmToCairo (Casm.Result labi instrs) + let res = run $ casmToCairo (Casm.Result labi instrs blts) outputFile = dirPath $(mkRelFile "out.json") encodeFile (toFilePath outputFile) res step "Run Cairo VM" @@ -35,8 +36,8 @@ casmRunVM labi instrs inputFile expectedFile step = do assertEqDiffText ("Check: RUN output = " <> toFilePath expectedFile) actualOutput expected ) -casmRunAssertion' :: Bool -> LabelInfo -> Code -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion -casmRunAssertion' bRunVM labi instrs inputFile expectedFile step = +casmRunAssertion' :: Bool -> LabelInfo -> Code -> [Builtin] -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion +casmRunAssertion' bRunVM labi instrs blts inputFile expectedFile step = case validate labi instrs of Left err -> do assertFailure (prettyString err) @@ -60,7 +61,7 @@ casmRunAssertion' bRunVM labi instrs inputFile expectedFile step = assertEqDiffText ("Check: RUN output = " <> toFilePath expectedFile) actualOutput expected ) when bRunVM $ - casmRunVM labi instrs inputFile expectedFile step + casmRunVM labi instrs blts inputFile expectedFile step casmRunAssertion :: Bool -> Path Abs File -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion casmRunAssertion bRunVM mainFile inputFile expectedFile step = do @@ -68,7 +69,7 @@ casmRunAssertion bRunVM mainFile inputFile expectedFile step = do r <- parseFile mainFile case r of Left err -> assertFailure (prettyString err) - Right (labi, instrs) -> casmRunAssertion' bRunVM labi instrs inputFile expectedFile step + Right (labi, instrs) -> casmRunAssertion' bRunVM labi instrs [] inputFile expectedFile step casmRunErrorAssertion :: Path Abs File -> (String -> IO ()) -> Assertion casmRunErrorAssertion mainFile step = do