1
1
mirror of https://github.com/anoma/juvix.git synced 2024-11-22 23:23:06 +03:00

Support for Cairo builtins (#2718)

This PR implements generic support for Cairo VM builtins. The calling
convention in the generated CASM code is changed to allow for passing
around the builtin pointers. Appropriate builtin initialization and
finalization code is added. Support for specific builtins (e.g. Poseidon
hash, range check, Elliptic Curve operation) still needs to be
implemented in separate PRs.

* Closes #2683
This commit is contained in:
Łukasz Czajka 2024-04-16 19:01:30 +02:00 committed by GitHub
parent 65176a333d
commit ad76c7a583
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 201 additions and 65 deletions

View File

@ -28,7 +28,7 @@ runCommand opts = do
runReader entryPoint runReader entryPoint
. runError @JuvixError . runError @JuvixError
. casmToCairo . casmToCairo
$ Casm.Result labi code $ Casm.Result labi code []
res <- getRight r res <- getRight r
liftIO $ JSON.encodeFile (toFilePath cairoFile) res liftIO $ JSON.encodeFile (toFilePath cairoFile) res
where where

View File

@ -20,7 +20,7 @@ juvix_get_ap_reg:
-- [fp - 3]: closure -- [fp - 3]: closure
-- [fp - 4]: n = the number of arguments to extend with -- [fp - 4]: n = the number of arguments to extend with
-- [fp - 4 - k]: argument n - k - 1 (reverse order!) -- [fp - 4 - k]: argument n - k - 1 (reverse order!) (k is 0-based)
juvix_extend_closure: juvix_extend_closure:
-- copy stored args reversing them; -- copy stored args reversing them;
-- to copy the stored args to the new closure -- to copy the stored args to the new closure
@ -95,10 +95,14 @@ juvix_extend_closure:
[ap] = [fp + 15]; ap++ [ap] = [fp + 15]; ap++
ret ret
-- [fp - 3]: closure; [fp - 3 - k]: argument k to closure call -- [fp - 3]: closure;
-- [fp - 4 - k]: argument k to closure call (0-based)
-- [fp - 4 - n]: builtin pointer, where n = number of supplied args
juvix_call_closure: juvix_call_closure:
-- jmp rel (9 - argsnum) -- jmp rel (9 - argsnum)
jmp rel [[fp - 3] + 2] jmp rel [[fp - 3] + 2]
-- builtin ptr + args
[ap] = [fp - 12]; ap++
[ap] = [fp - 11]; ap++ [ap] = [fp - 11]; ap++
[ap] = [fp - 10]; ap++ [ap] = [fp - 10]; ap++
[ap] = [fp - 9]; ap++ [ap] = [fp - 9]; ap++

View File

@ -2,4 +2,4 @@
BASE=`basename "$1" .json` BASE=`basename "$1" .json`
juvix-cairo-vm "$@" --print_output --proof_mode --trace_file ${BASE}.trace --air_public_input=${BASE}_public_input.json --air_private_input=${BASE}_private_input.json --memory_file=${BASE}_memory.mem --layout=small juvix-cairo-vm "$@" --print_output --proof_mode --trace_file ${BASE}.trace --air_public_input=${BASE}_public_input.json --air_private_input=${BASE}_private_input.json --memory_file=${BASE}_memory.mem --layout=all_cairo

View File

@ -5,26 +5,29 @@ import Juvix.Compiler.Backend.Cairo.Data.Result
import Juvix.Compiler.Backend.Cairo.Language import Juvix.Compiler.Backend.Cairo.Language
import Numeric import Numeric
serialize :: [Element] -> Result serialize :: [Text] -> [Element] -> Result
serialize elems = serialize builtins elems =
Result Result
{ _resultData = { _resultData =
initializeOutput initializeBuiltins
++ map toHexText (serialize' elems) ++ map toHexText (serialize' elems)
++ finalizeOutput ++ finalizeBuiltins
++ finalizeJump, ++ finalizeJump,
_resultStart = 0, _resultStart = 0,
_resultEnd = length initializeOutput + length elems + length finalizeOutput, _resultEnd = length initializeBuiltins + length elems + length finalizeBuiltins,
_resultMain = 0, _resultMain = 0,
_resultHints = hints, _resultHints = hints,
_resultBuiltins = ["output"] _resultBuiltins = "output" : builtins
} }
where where
builtinsNum :: Natural
builtinsNum = fromIntegral (length builtins)
hints :: [(Int, Text)] hints :: [(Int, Text)]
hints = catMaybes $ zipWith mkHint elems [0 ..] hints = catMaybes $ zipWith mkHint elems [0 ..]
pcShift :: Int pcShift :: Int
pcShift = length initializeOutput pcShift = length initializeBuiltins
mkHint :: Element -> Int -> Maybe (Int, Text) mkHint :: Element -> Int -> Maybe (Int, Text)
mkHint el pc = case el of mkHint el pc = case el of
@ -34,21 +37,30 @@ serialize elems =
toHexText :: Natural -> Text toHexText :: Natural -> Text
toHexText n = "0x" <> fromString (showHex n "") toHexText n = "0x" <> fromString (showHex n "")
initializeOutput :: [Text] initializeBuiltins :: [Text]
initializeOutput = initializeBuiltins =
-- ap += allBuiltinsNum
[ "0x40480017fff7fff", [ "0x40480017fff7fff",
"0x1" toHexText (builtinsNum + 1)
] ]
finalizeOutput :: [Text] finalizeBuiltins :: [Text]
finalizeOutput = finalizeBuiltins =
-- [[fp]] = [ap - 1] -- [output_ptr] = [ap - 1]
-- [ap] = [fp] + 1; ap++ -- output_ptr
[ "0x4002800080007fff", [ "0x4002800080007fff",
"0x4826800180008000", "0x4826800180008000",
"0x1" "0x1"
] ]
++
-- [ap] = [ap - builtinsNum - 2]; ap++
replicate
builtinsNum
(toHexText (0x48107ffe7fff8000 - shift builtinsNum 32))
finalizeJump :: [Text] finalizeJump :: [Text]
finalizeJump = finalizeJump =
-- jmp rel 0
[ "0x10780017fff7fff", [ "0x10780017fff7fff",
"0x0" "0x0"
] ]

View File

@ -0,0 +1,27 @@
module Juvix.Compiler.Casm.Data.Builtins where
import Juvix.Extra.Strings qualified as Str
import Juvix.Prelude
-- The order of the builtins must correspond to the "standard" builtin order in
-- the Cairo VM implementation. See:
-- https://github.com/lambdaclass/cairo-vm/blob/main/vm/src/vm/runners/cairo_runner.rs#L257
data Builtin
= BuiltinRangeCheck
| BuiltinEcOp
| BuiltinPoseidon
deriving stock (Show, Eq, Generic, Enum, Bounded)
instance Hashable Builtin
builtinsNum :: Int
builtinsNum = length (allElements @Builtin)
builtinName :: Builtin -> Text
builtinName = \case
BuiltinRangeCheck -> Str.cairoRangeCheck
BuiltinEcOp -> Str.cairoEcOp
BuiltinPoseidon -> Str.cairoPoseidon
builtinNames :: [Text]
builtinNames = map builtinName allElements

View File

@ -1,4 +1,8 @@
module Juvix.Compiler.Casm.Data.LabelInfoBuilder where module Juvix.Compiler.Casm.Data.LabelInfoBuilder
( module Juvix.Compiler.Casm.Data.LabelInfo,
module Juvix.Compiler.Casm.Data.LabelInfoBuilder,
)
where
import Data.HashMap.Strict qualified as HashMap import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Casm.Data.LabelInfo import Juvix.Compiler.Casm.Data.LabelInfo

View File

@ -1,11 +1,13 @@
module Juvix.Compiler.Casm.Data.Result where module Juvix.Compiler.Casm.Data.Result where
import Juvix.Compiler.Casm.Data.Builtins
import Juvix.Compiler.Casm.Data.LabelInfo import Juvix.Compiler.Casm.Data.LabelInfo
import Juvix.Compiler.Casm.Language import Juvix.Compiler.Casm.Language
data Result = Result data Result = Result
{ _resultLabelInfo :: LabelInfo, { _resultLabelInfo :: LabelInfo,
_resultCode :: [Instruction] _resultCode :: [Instruction],
_resultBuiltins :: [Builtin]
} }
makeLenses ''Result makeLenses ''Result

View File

@ -11,6 +11,7 @@ import Data.HashMap.Strict qualified as HashMap
import Data.Vector qualified as Vec import Data.Vector qualified as Vec
import Data.Vector.Mutable qualified as MV import Data.Vector.Mutable qualified as MV
import GHC.IO qualified as GHC import GHC.IO qualified as GHC
import Juvix.Compiler.Casm.Data.Builtins
import Juvix.Compiler.Casm.Data.InputInfo import Juvix.Compiler.Casm.Data.InputInfo
import Juvix.Compiler.Casm.Data.LabelInfo import Juvix.Compiler.Casm.Data.LabelInfo
import Juvix.Compiler.Casm.Error import Juvix.Compiler.Casm.Error
@ -39,7 +40,9 @@ hRunCode hout inputInfo (LabelInfo labelInfo) instrs0 = runST goCode
goCode :: ST s FField goCode :: ST s FField
goCode = do goCode = do
mem <- MV.replicate initialMemSize Nothing mem <- MV.replicate initialMemSize Nothing
go 0 0 0 mem forM_ [0 .. builtinsNum] $ \k ->
MV.write mem k (Just (fieldFromInteger cairoFieldSize 0))
go 0 (builtinsNum + 1) 0 mem
go :: go ::
Address -> Address ->

View File

@ -5,7 +5,12 @@ import Juvix.Compiler.Casm.Data.Result
import Juvix.Compiler.Casm.Language import Juvix.Compiler.Casm.Language
fromCairo :: [Cairo.Element] -> Result fromCairo :: [Cairo.Element] -> Result
fromCairo elems0 = Result mempty (go 0 [] elems0) fromCairo elems0 =
Result
{ _resultLabelInfo = mempty,
_resultCode = go 0 [] elems0,
_resultBuiltins = mempty
}
where where
errorMsg :: Address -> Text -> a errorMsg :: Address -> Text -> a
errorMsg addr msg = error ("error at address " <> show addr <> ": " <> msg) errorMsg addr msg = error ("error at address " <> show addr <> ": " <> msg)

View File

@ -3,6 +3,7 @@ module Juvix.Compiler.Casm.Translation.FromReg where
import Data.HashMap.Strict qualified as HashMap import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet import Data.HashSet qualified as HashSet
import Data.Text qualified as Text import Data.Text qualified as Text
import Juvix.Compiler.Casm.Data.Builtins
import Juvix.Compiler.Casm.Data.LabelInfoBuilder import Juvix.Compiler.Casm.Data.LabelInfoBuilder
import Juvix.Compiler.Casm.Data.Limits import Juvix.Compiler.Casm.Data.Limits
import Juvix.Compiler.Casm.Data.Result import Juvix.Compiler.Casm.Data.Result
@ -17,26 +18,56 @@ import Juvix.Compiler.Tree.Evaluator.Builtins qualified as Reg
import Juvix.Data.Field import Juvix.Data.Field
fromReg :: Reg.InfoTable -> Result fromReg :: Reg.InfoTable -> Result
fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolId tab) $ do fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolId tab) $ do
let startAddr :: Address = 2
startSym <- freshSymbol
endSym <- freshSymbol
let startName :: Text = "__juvix_start"
startLab = LabelRef startSym (Just startName)
endName :: Text = "__juvix_end"
endLab = LabelRef endSym (Just endName)
registerLabelName startSym startName
registerLabelAddress startSym startAddr
let mainSym = fromJust $ tab ^. Reg.infoMainFunction let mainSym = fromJust $ tab ^. Reg.infoMainFunction
mainInfo = fromJust (HashMap.lookup mainSym (tab ^. Reg.infoFunctions)) mainInfo = fromJust (HashMap.lookup mainSym (tab ^. Reg.infoFunctions))
mainName = mainInfo ^. Reg.functionName mainName = mainInfo ^. Reg.functionName
mainArgs = getInputArgs (mainInfo ^. Reg.functionArgsNum) (mainInfo ^. Reg.functionArgNames) mainArgs = getInputArgs (mainInfo ^. Reg.functionArgsNum) (mainInfo ^. Reg.functionArgNames)
initialOffset = length mainArgs + 2 bnum = toOffset builtinsNum
(blts, binstrs) <- addStdlibBuiltins initialOffset callStartInstr = mkCallRel (Lab startLab)
initBuiltinsInstr = mkAssignAp (Binop $ BinopValue FieldAdd (MemRef Fp (-2)) (Imm 1))
callMainInstr = mkCallRel (Lab $ LabelRef mainSym (Just mainName))
jmpEndInstr = mkJumpRel (Val $ Lab endLab)
margs = reverse $ map (Hint . HintInput) mainArgs
-- [ap] = [[ap - 2 - k] + k]; ap++
bltsRet = map (\k -> mkAssignAp (Load $ LoadValue (MemRef Ap (-2 - k)) k)) [0 .. bnum - 1]
resRetInstr = mkAssignAp (Val $ Ref $ MemRef Ap (-bnum - 1))
pinstrs =
callStartInstr
: jmpEndInstr
: Label startLab
: initBuiltinsInstr
: margs
++ callMainInstr
: bltsRet
++ [resRetInstr, Return]
(blts, binstrs) <- addStdlibBuiltins (length pinstrs)
let cinstrs = concatMap (mkFunCall . fst) $ sortOn snd $ HashMap.toList (info ^. Reg.extraInfoFUIDs) let cinstrs = concatMap (mkFunCall . fst) $ sortOn snd $ HashMap.toList (info ^. Reg.extraInfoFUIDs)
endSym <- freshSymbol (addr, instrs) <- second (concat . reverse) <$> foldM (goFun blts endLab) (length pinstrs + length binstrs + length cinstrs, []) (tab ^. Reg.infoFunctions)
let endName :: Text = "__juvix_end" eassert (addr == length instrs + length cinstrs + length binstrs + length pinstrs)
endLab = LabelRef endSym (Just endName)
(addr, instrs) <- second (concat . reverse) <$> foldM (goFun blts endLab) (initialOffset + length binstrs + length cinstrs, []) (tab ^. Reg.infoFunctions)
eassert (addr == length instrs + length cinstrs + length binstrs + initialOffset)
registerLabelName endSym endName registerLabelName endSym endName
registerLabelAddress endSym addr registerLabelAddress endSym addr
let callInstr = mkCallRel (Lab $ LabelRef mainSym (Just mainName)) return $
jmpInstr = mkJumpRel (Val $ Lab endLab) ( allElements,
margs = reverse $ map (Hint . HintInput) mainArgs pinstrs
return $ margs ++ callInstr : jmpInstr : binstrs ++ cinstrs ++ instrs ++ [Label endLab] ++ binstrs
++ cinstrs
++ instrs
++ [Label endLab]
)
where where
mkResult :: (LabelInfo, ([Builtin], Code)) -> Result
mkResult (labi, (blts, code)) = Result labi code blts
info :: Reg.ExtraInfo info :: Reg.ExtraInfo
info = Reg.computeExtraInfo tab info = Reg.computeExtraInfo tab
@ -72,6 +103,9 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
("fp", "__fp__") ("fp", "__fp__")
] ]
argsOffset :: Int
argsOffset = 3
goFun :: forall r. (Member LabelInfoBuilder r) => StdlibBuiltins -> LabelRef -> (Address, [[Instruction]]) -> Reg.FunctionInfo -> Sem r (Address, [[Instruction]]) goFun :: forall r. (Member LabelInfoBuilder r) => StdlibBuiltins -> LabelRef -> (Address, [[Instruction]]) -> Reg.FunctionInfo -> Sem r (Address, [[Instruction]])
goFun blts failLab (addr0, acc) funInfo = do goFun blts failLab (addr0, acc) funInfo = do
let sym = funInfo ^. Reg.functionSymbol let sym = funInfo ^. Reg.functionSymbol
@ -85,10 +119,10 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
n = funInfo ^. Reg.functionArgsNum n = funInfo ^. Reg.functionArgsNum
let vars = let vars =
HashMap.fromList $ HashMap.fromList $
map (\k -> (Reg.VarRef Reg.VarGroupArgs k Nothing, -3 - k)) [0 .. n - 1] map (\k -> (Reg.VarRef Reg.VarGroupArgs k Nothing, -argsOffset - k)) [0 .. n - 1]
instrs <- instrs <-
fmap fst fmap fst
. runCasmBuilder addr1 vars . runCasmBuilder addr1 vars (-argsOffset - n)
. runOutputList . runOutputList
$ goBlock blts failLab mempty Nothing block $ goBlock blts failLab mempty Nothing block
return (addr1 + length instrs, (pre ++ instrs) : acc) return (addr1 + length instrs, (pre ++ instrs) : acc)
@ -119,7 +153,7 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
Nothing -> do Nothing -> do
eassert (isJust mout) eassert (isJust mout)
eassert (HashSet.member (fromJust mout) liveVars0) eassert (HashSet.member (fromJust mout) liveVars0)
goCallBlock Nothing liveVars0 goCallBlock False Nothing liveVars0
where where
output'' :: Instruction -> Sem r () output'' :: Instruction -> Sem r ()
output'' i = do output'' i = do
@ -131,14 +165,22 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
output'' i output'' i
incAP apOff incAP apOff
goCallBlock :: Maybe Reg.VarRef -> HashSet Reg.VarRef -> Sem r () goCallBlock :: Bool -> Maybe Reg.VarRef -> HashSet Reg.VarRef -> Sem r ()
goCallBlock outVar liveVars = do goCallBlock updatedBuiltins outVar liveVars = do
let liveVars' = toList (maybe liveVars (flip HashSet.delete liveVars) outVar) let liveVars' = toList (maybe liveVars (flip HashSet.delete liveVars) outVar)
n = length liveVars' n = length liveVars'
bltOff =
if
| updatedBuiltins ->
-argsOffset - n - fromEnum (isJust outVar)
| otherwise ->
-argsOffset - n
vars = vars =
HashMap.fromList $ HashMap.fromList $
maybe [] (\var -> [(var, -3 - n)]) outVar maybe [] (\var -> [(var, -argsOffset - n - if updatedBuiltins then 0 else 1)]) outVar
++ zipWithExact (\var k -> (var, -3 - k)) liveVars' [0 .. n - 1] ++ zipWithExact (\var k -> (var, -argsOffset - k)) liveVars' [0 .. n - 1]
unless updatedBuiltins $
goAssignApBuiltins
mapM_ (mkMemRef >=> goAssignAp . Val . Ref) (reverse liveVars') mapM_ (mkMemRef >=> goAssignAp . Val . Ref) (reverse liveVars')
output'' (mkCallRel $ Imm 3) output'' (mkCallRel $ Imm 3)
output'' Return output'' Return
@ -148,11 +190,13 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
output'' Nop output'' Nop
setAP 0 setAP 0
setVars vars setVars vars
setBuiltinOffset bltOff
goLocalBlock :: Int -> HashMap Reg.VarRef Int -> HashSet Reg.VarRef -> Maybe Reg.VarRef -> Reg.Block -> Sem r () goLocalBlock :: Int -> HashMap Reg.VarRef Int -> Int -> HashSet Reg.VarRef -> Maybe Reg.VarRef -> Reg.Block -> Sem r ()
goLocalBlock ap0 vars liveVars mout' block = do goLocalBlock ap0 vars bltOff liveVars mout' block = do
setAP ap0 setAP ap0
setVars vars setVars vars
setBuiltinOffset bltOff
goBlock blts failLab liveVars mout' block goBlock blts failLab liveVars mout' block
---------------------------------------------------------------------- ----------------------------------------------------------------------
@ -179,6 +223,11 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
v <- lookupVar' vr v <- lookupVar' vr
return $ MemRef Fp (toOffset v) return $ MemRef Fp (toOffset v)
mkBuiltinRef :: Sem r MemRef
mkBuiltinRef = do
off <- getBuiltinOffset
return $ MemRef Fp (toOffset off)
mkRValue :: Reg.Value -> Sem r RValue mkRValue :: Reg.Value -> Sem r RValue
mkRValue = \case mkRValue = \case
Reg.ValConst c -> return $ Val $ Imm $ mkConst c Reg.ValConst c -> return $ Val $ Imm $ mkConst c
@ -216,6 +265,9 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
goAssignApValue :: Reg.Value -> Sem r () goAssignApValue :: Reg.Value -> Sem r ()
goAssignApValue v = mkRValue v >>= goAssignAp goAssignApValue v = mkRValue v >>= goAssignAp
goAssignApBuiltins :: Sem r ()
goAssignApBuiltins = mkBuiltinRef >>= goAssignAp . Val . Ref
goValue :: Reg.Value -> Sem r Value goValue :: Reg.Value -> Sem r Value
goValue = \case goValue = \case
Reg.ValConst c -> return $ Imm $ mkConst c Reg.ValConst c -> return $ Imm $ mkConst c
@ -414,16 +466,18 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
val <- mkMemRef _instrExtendClosureValue val <- mkMemRef _instrExtendClosureValue
goAssignAp (Val $ Ref val) goAssignAp (Val $ Ref val)
output'' $ mkCallRel $ Lab $ LabelRef (blts ^. stdlibExtendClosure) (Just (blts ^. stdlibExtendClosureName)) output'' $ mkCallRel $ Lab $ LabelRef (blts ^. stdlibExtendClosure) (Just (blts ^. stdlibExtendClosureName))
goCallBlock (Just _instrExtendClosureResult) liveVars goCallBlock False (Just _instrExtendClosureResult) liveVars
goCall' :: Reg.CallType -> [Reg.Value] -> Sem r () goCall' :: Reg.CallType -> [Reg.Value] -> Sem r ()
goCall' ct args = case ct of goCall' ct args = case ct of
Reg.CallFun sym -> do Reg.CallFun sym -> do
goAssignApBuiltins
mapM_ goAssignApValue (reverse args) mapM_ goAssignApValue (reverse args)
output'' $ mkCallRel $ Lab $ LabelRef sym (Just funName) output'' $ mkCallRel $ Lab $ LabelRef sym (Just funName)
where where
funName = quoteName (Reg.lookupFunInfo tab sym ^. Reg.functionName) funName = quoteName (Reg.lookupFunInfo tab sym ^. Reg.functionName)
Reg.CallClosure cl -> do Reg.CallClosure cl -> do
goAssignApBuiltins
mapM_ goAssignApValue (reverse args) mapM_ goAssignApValue (reverse args)
r <- mkMemRef cl r <- mkMemRef cl
goAssignAp (Val $ Ref r) goAssignAp (Val $ Ref r)
@ -432,7 +486,7 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
goCall :: HashSet Reg.VarRef -> Reg.InstrCall -> Sem r () goCall :: HashSet Reg.VarRef -> Reg.InstrCall -> Sem r ()
goCall liveVars Reg.InstrCall {..} = do goCall liveVars Reg.InstrCall {..} = do
goCall' _instrCallType _instrCallArgs goCall' _instrCallType _instrCallArgs
goCallBlock (Just _instrCallResult) liveVars goCallBlock True (Just _instrCallResult) liveVars
-- There is no way to make "proper" tail calls in Cairo, because -- There is no way to make "proper" tail calls in Cairo, because
-- the only way to set the `fp` register is via the `call` instruction. -- the only way to set the `fp` register is via the `call` instruction.
@ -444,6 +498,7 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
goReturn :: Reg.InstrReturn -> Sem r () goReturn :: Reg.InstrReturn -> Sem r ()
goReturn Reg.InstrReturn {..} = do goReturn Reg.InstrReturn {..} = do
goAssignApBuiltins
goAssignApValue _instrReturnValue goAssignApValue _instrReturnValue
output'' Return output'' Return
@ -462,14 +517,15 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
output'' $ mkJumpIf (Lab labFalse) r output'' $ mkJumpIf (Lab labFalse) r
ap0 <- getAP ap0 <- getAP
vars <- getVars vars <- getVars
goLocalBlock ap0 vars liveVars _instrBranchOutVar _instrBranchTrue bltOff <- getBuiltinOffset
goLocalBlock ap0 vars bltOff liveVars _instrBranchOutVar _instrBranchTrue
-- _instrBranchOutVar is Nothing iff the branch returns -- _instrBranchOutVar is Nothing iff the branch returns
when (isJust _instrBranchOutVar) $ when (isJust _instrBranchOutVar) $
output'' (mkJumpRel (Val $ Lab labEnd)) output'' (mkJumpRel (Val $ Lab labEnd))
addrFalse <- getPC addrFalse <- getPC
registerLabelAddress symFalse addrFalse registerLabelAddress symFalse addrFalse
output'' $ Label labFalse output'' $ Label labFalse
goLocalBlock ap0 vars liveVars _instrBranchOutVar _instrBranchFalse goLocalBlock ap0 vars bltOff liveVars _instrBranchOutVar _instrBranchFalse
addrEnd <- getPC addrEnd <- getPC
registerLabelAddress symEnd addrEnd registerLabelAddress symEnd addrEnd
output'' $ Label labEnd output'' $ Label labEnd
@ -501,10 +557,11 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
mapM_ output'' jmps' mapM_ output'' jmps'
ap0 <- getAP ap0 <- getAP
vars <- getVars vars <- getVars
mapM_ (goCaseBranch ap0 vars symMap labEnd) _instrCaseBranches bltOff <- getBuiltinOffset
mapM_ (goCaseBranch ap0 vars bltOff symMap labEnd) _instrCaseBranches
mapM_ (goDefaultLabel symMap) defaultTags mapM_ (goDefaultLabel symMap) defaultTags
whenJust _instrCaseDefault $ whenJust _instrCaseDefault $
goLocalBlock ap0 vars liveVars _instrCaseOutVar goLocalBlock ap0 vars bltOff liveVars _instrCaseOutVar
addrEnd <- getPC addrEnd <- getPC
registerLabelAddress symEnd addrEnd registerLabelAddress symEnd addrEnd
output'' $ Label labEnd output'' $ Label labEnd
@ -513,14 +570,14 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
ctrTags = HashSet.fromList $ map (^. Reg.caseBranchTag) _instrCaseBranches ctrTags = HashSet.fromList $ map (^. Reg.caseBranchTag) _instrCaseBranches
defaultTags = filter (not . flip HashSet.member ctrTags) tags defaultTags = filter (not . flip HashSet.member ctrTags) tags
goCaseBranch :: Int -> HashMap Reg.VarRef Int -> HashMap Tag Symbol -> LabelRef -> Reg.CaseBranch -> Sem r () goCaseBranch :: Int -> HashMap Reg.VarRef Int -> Int -> HashMap Tag Symbol -> LabelRef -> Reg.CaseBranch -> Sem r ()
goCaseBranch ap0 vars symMap labEnd Reg.CaseBranch {..} = do goCaseBranch ap0 vars bltOff symMap labEnd Reg.CaseBranch {..} = do
let sym = fromJust $ HashMap.lookup _caseBranchTag symMap let sym = fromJust $ HashMap.lookup _caseBranchTag symMap
lab = LabelRef sym Nothing lab = LabelRef sym Nothing
addr <- getPC addr <- getPC
registerLabelAddress sym addr registerLabelAddress sym addr
output'' $ Label lab output'' $ Label lab
goLocalBlock ap0 vars liveVars _instrCaseOutVar _caseBranchCode goLocalBlock ap0 vars bltOff liveVars _instrCaseOutVar _caseBranchCode
-- _instrCaseOutVar is Nothing iff the branch returns -- _instrCaseOutVar is Nothing iff the branch returns
when (isJust _instrCaseOutVar) $ when (isJust _instrCaseOutVar) $
output'' (mkJumpRel (Val $ Lab labEnd)) output'' (mkJumpRel (Val $ Lab labEnd))

View File

@ -14,27 +14,31 @@ data CasmBuilder :: Effect where
LookupVar :: VarRef -> CasmBuilder m (Maybe Int) LookupVar :: VarRef -> CasmBuilder m (Maybe Int)
GetVars :: CasmBuilder m (HashMap VarRef Int) GetVars :: CasmBuilder m (HashMap VarRef Int)
SetVars :: HashMap VarRef Int -> CasmBuilder m () SetVars :: HashMap VarRef Int -> CasmBuilder m ()
GetBuiltinOffset :: CasmBuilder m Int
SetBuiltinOffset :: Int -> CasmBuilder m ()
makeSem ''CasmBuilder makeSem ''CasmBuilder
data BuilderState = BuilderState data BuilderState = BuilderState
{ _statePC :: Address, { _statePC :: Address,
_stateAP :: Int, _stateAP :: Int,
_stateVarMap :: HashMap VarRef Int _stateVarMap :: HashMap VarRef Int,
_stateBuiltinOff :: Int
} }
makeLenses ''BuilderState makeLenses ''BuilderState
mkBuilderState :: Address -> HashMap VarRef Int -> BuilderState mkBuilderState :: Address -> HashMap VarRef Int -> Int -> BuilderState
mkBuilderState addr vars = mkBuilderState addr vars bltOff =
BuilderState BuilderState
{ _statePC = addr, { _statePC = addr,
_stateAP = 0, _stateAP = 0,
_stateVarMap = vars _stateVarMap = vars,
_stateBuiltinOff = bltOff
} }
runCasmBuilder :: Address -> HashMap VarRef Int -> Sem (CasmBuilder ': r) a -> Sem r a runCasmBuilder :: Address -> HashMap VarRef Int -> Int -> Sem (CasmBuilder ': r) a -> Sem r a
runCasmBuilder addr vars = fmap snd . runCasmBuilder' (mkBuilderState addr vars) runCasmBuilder addr vars bltOff = fmap snd . runCasmBuilder' (mkBuilderState addr vars bltOff)
runCasmBuilder' :: BuilderState -> Sem (CasmBuilder ': r) a -> Sem r (BuilderState, a) runCasmBuilder' :: BuilderState -> Sem (CasmBuilder ': r) a -> Sem r (BuilderState, a)
runCasmBuilder' bs = reinterpret (runState bs) interp runCasmBuilder' bs = reinterpret (runState bs) interp
@ -60,6 +64,10 @@ runCasmBuilder' bs = reinterpret (runState bs) interp
gets (^. stateVarMap) gets (^. stateVarMap)
SetVars vars -> do SetVars vars -> do
modify' (set stateVarMap vars) modify' (set stateVarMap vars)
GetBuiltinOffset -> do
gets (^. stateBuiltinOff)
SetBuiltinOffset bltOff -> do
modify' (set stateBuiltinOff bltOff)
lookupVar' :: (Member CasmBuilder r) => VarRef -> Sem r Int lookupVar' :: (Member CasmBuilder r) => VarRef -> Sem r Int
lookupVar' = lookupVar >=> return . fromJust lookupVar' = lookupVar >=> return . fromJust

View File

@ -17,6 +17,7 @@ import Juvix.Compiler.Backend.C qualified as C
import Juvix.Compiler.Backend.Cairo qualified as Cairo import Juvix.Compiler.Backend.Cairo qualified as Cairo
import Juvix.Compiler.Backend.Geb qualified as Geb import Juvix.Compiler.Backend.Geb qualified as Geb
import Juvix.Compiler.Backend.VampIR.Translation qualified as VampIR import Juvix.Compiler.Backend.VampIR.Translation qualified as VampIR
import Juvix.Compiler.Casm.Data.Builtins qualified as Casm
import Juvix.Compiler.Casm.Data.Result qualified as Casm import Juvix.Compiler.Casm.Data.Result qualified as Casm
import Juvix.Compiler.Casm.Translation.FromReg qualified as Casm import Juvix.Compiler.Casm.Translation.FromReg qualified as Casm
import Juvix.Compiler.Concrete.Data.Highlight.Input import Juvix.Compiler.Concrete.Data.Highlight.Input
@ -278,7 +279,10 @@ regToCasm :: Reg.InfoTable -> Sem r Casm.Result
regToCasm = Reg.toCasm >=> return . Casm.fromReg regToCasm = Reg.toCasm >=> return . Casm.fromReg
casmToCairo :: Casm.Result -> Sem r Cairo.Result casmToCairo :: Casm.Result -> Sem r Cairo.Result
casmToCairo Casm.Result {..} = return $ Cairo.serialize $ Cairo.fromCasm _resultCode casmToCairo Casm.Result {..} =
return
. Cairo.serialize (map Casm.builtinName _resultBuiltins)
$ Cairo.fromCasm _resultCode
regToCairo :: Reg.InfoTable -> Sem r Cairo.Result regToCairo :: Reg.InfoTable -> Sem r Cairo.Result
regToCairo = regToCasm >=> casmToCairo regToCairo = regToCasm >=> casmToCairo

View File

@ -991,3 +991,12 @@ functionsPlaceholder = "functionsLibrary_placeholder"
theFunctionsLibrary :: (IsString s) => s theFunctionsLibrary :: (IsString s) => s
theFunctionsLibrary = "the_functionsLibrary" theFunctionsLibrary = "the_functionsLibrary"
cairoRangeCheck :: (IsString s) => s
cairoRangeCheck = "range_check"
cairoPoseidon :: (IsString s) => s
cairoPoseidon = "poseidon"
cairoEcOp :: (IsString s) => s
cairoEcOp = "ec_op"

View File

@ -42,4 +42,4 @@ compileAssertionEntry adjustEntry root' bRunVM optLevel mainFile expectedFile st
step "Pretty print" step "Pretty print"
writeFileEnsureLn tmpFile (toPlainText $ ppProgram _resultCode) writeFileEnsureLn tmpFile (toPlainText $ ppProgram _resultCode)
) )
casmRunAssertion' bRunVM _resultLabelInfo _resultCode Nothing expectedFile step casmRunAssertion' bRunVM _resultLabelInfo _resultCode _resultBuiltins Nothing expectedFile step

View File

@ -2,6 +2,7 @@ module Casm.Run.Base where
import Base import Base
import Data.Aeson import Data.Aeson
import Juvix.Compiler.Casm.Data.Builtins
import Juvix.Compiler.Casm.Data.Result qualified as Casm import Juvix.Compiler.Casm.Data.Result qualified as Casm
import Juvix.Compiler.Casm.Error import Juvix.Compiler.Casm.Error
import Juvix.Compiler.Casm.Extra.InputInfo import Juvix.Compiler.Casm.Extra.InputInfo
@ -18,14 +19,14 @@ casmRunVM' dirPath outputFile inputFile = do
let args = maybe [] (\f -> ["--program_input", toFilePath f]) inputFile let args = maybe [] (\f -> ["--program_input", toFilePath f]) inputFile
R.readProcessCwd (toFilePath dirPath) "run_cairo_vm.sh" (toFilePath outputFile : args) "" R.readProcessCwd (toFilePath dirPath) "run_cairo_vm.sh" (toFilePath outputFile : args) ""
casmRunVM :: LabelInfo -> Code -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion casmRunVM :: LabelInfo -> Code -> [Builtin] -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion
casmRunVM labi instrs inputFile expectedFile step = do casmRunVM labi instrs blts inputFile expectedFile step = do
step "Check run_cairo_vm.sh is on path" step "Check run_cairo_vm.sh is on path"
assertCmdExists $(mkRelFile "run_cairo_vm.sh") assertCmdExists $(mkRelFile "run_cairo_vm.sh")
withTempDir' withTempDir'
( \dirPath -> do ( \dirPath -> do
step "Serialize to Cairo bytecode" step "Serialize to Cairo bytecode"
let res = run $ casmToCairo (Casm.Result labi instrs) let res = run $ casmToCairo (Casm.Result labi instrs blts)
outputFile = dirPath <//> $(mkRelFile "out.json") outputFile = dirPath <//> $(mkRelFile "out.json")
encodeFile (toFilePath outputFile) res encodeFile (toFilePath outputFile) res
step "Run Cairo VM" step "Run Cairo VM"
@ -35,8 +36,8 @@ casmRunVM labi instrs inputFile expectedFile step = do
assertEqDiffText ("Check: RUN output = " <> toFilePath expectedFile) actualOutput expected assertEqDiffText ("Check: RUN output = " <> toFilePath expectedFile) actualOutput expected
) )
casmRunAssertion' :: Bool -> LabelInfo -> Code -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion casmRunAssertion' :: Bool -> LabelInfo -> Code -> [Builtin] -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion
casmRunAssertion' bRunVM labi instrs inputFile expectedFile step = casmRunAssertion' bRunVM labi instrs blts inputFile expectedFile step =
case validate labi instrs of case validate labi instrs of
Left err -> do Left err -> do
assertFailure (prettyString err) assertFailure (prettyString err)
@ -60,7 +61,7 @@ casmRunAssertion' bRunVM labi instrs inputFile expectedFile step =
assertEqDiffText ("Check: RUN output = " <> toFilePath expectedFile) actualOutput expected assertEqDiffText ("Check: RUN output = " <> toFilePath expectedFile) actualOutput expected
) )
when bRunVM $ when bRunVM $
casmRunVM labi instrs inputFile expectedFile step casmRunVM labi instrs blts inputFile expectedFile step
casmRunAssertion :: Bool -> Path Abs File -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion casmRunAssertion :: Bool -> Path Abs File -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion
casmRunAssertion bRunVM mainFile inputFile expectedFile step = do casmRunAssertion bRunVM mainFile inputFile expectedFile step = do
@ -68,7 +69,7 @@ casmRunAssertion bRunVM mainFile inputFile expectedFile step = do
r <- parseFile mainFile r <- parseFile mainFile
case r of case r of
Left err -> assertFailure (prettyString err) Left err -> assertFailure (prettyString err)
Right (labi, instrs) -> casmRunAssertion' bRunVM labi instrs inputFile expectedFile step Right (labi, instrs) -> casmRunAssertion' bRunVM labi instrs [] inputFile expectedFile step
casmRunErrorAssertion :: Path Abs File -> (String -> IO ()) -> Assertion casmRunErrorAssertion :: Path Abs File -> (String -> IO ()) -> Assertion
casmRunErrorAssertion mainFile step = do casmRunErrorAssertion mainFile step = do