mirror of
https://github.com/anoma/juvix.git
synced 2024-11-22 14:17:29 +03:00
Support for Cairo builtins (#2718)
This PR implements generic support for Cairo VM builtins. The calling convention in the generated CASM code is changed to allow for passing around the builtin pointers. Appropriate builtin initialization and finalization code is added. Support for specific builtins (e.g. Poseidon hash, range check, Elliptic Curve operation) still needs to be implemented in separate PRs. * Closes #2683
This commit is contained in:
parent
65176a333d
commit
ad76c7a583
@ -28,7 +28,7 @@ runCommand opts = do
|
||||
runReader entryPoint
|
||||
. runError @JuvixError
|
||||
. casmToCairo
|
||||
$ Casm.Result labi code
|
||||
$ Casm.Result labi code []
|
||||
res <- getRight r
|
||||
liftIO $ JSON.encodeFile (toFilePath cairoFile) res
|
||||
where
|
||||
|
@ -20,7 +20,7 @@ juvix_get_ap_reg:
|
||||
|
||||
-- [fp - 3]: closure
|
||||
-- [fp - 4]: n = the number of arguments to extend with
|
||||
-- [fp - 4 - k]: argument n - k - 1 (reverse order!)
|
||||
-- [fp - 4 - k]: argument n - k - 1 (reverse order!) (k is 0-based)
|
||||
juvix_extend_closure:
|
||||
-- copy stored args reversing them;
|
||||
-- to copy the stored args to the new closure
|
||||
@ -95,10 +95,14 @@ juvix_extend_closure:
|
||||
[ap] = [fp + 15]; ap++
|
||||
ret
|
||||
|
||||
-- [fp - 3]: closure; [fp - 3 - k]: argument k to closure call
|
||||
-- [fp - 3]: closure;
|
||||
-- [fp - 4 - k]: argument k to closure call (0-based)
|
||||
-- [fp - 4 - n]: builtin pointer, where n = number of supplied args
|
||||
juvix_call_closure:
|
||||
-- jmp rel (9 - argsnum)
|
||||
jmp rel [[fp - 3] + 2]
|
||||
-- builtin ptr + args
|
||||
[ap] = [fp - 12]; ap++
|
||||
[ap] = [fp - 11]; ap++
|
||||
[ap] = [fp - 10]; ap++
|
||||
[ap] = [fp - 9]; ap++
|
||||
|
@ -2,4 +2,4 @@
|
||||
|
||||
BASE=`basename "$1" .json`
|
||||
|
||||
juvix-cairo-vm "$@" --print_output --proof_mode --trace_file ${BASE}.trace --air_public_input=${BASE}_public_input.json --air_private_input=${BASE}_private_input.json --memory_file=${BASE}_memory.mem --layout=small
|
||||
juvix-cairo-vm "$@" --print_output --proof_mode --trace_file ${BASE}.trace --air_public_input=${BASE}_public_input.json --air_private_input=${BASE}_private_input.json --memory_file=${BASE}_memory.mem --layout=all_cairo
|
||||
|
@ -5,26 +5,29 @@ import Juvix.Compiler.Backend.Cairo.Data.Result
|
||||
import Juvix.Compiler.Backend.Cairo.Language
|
||||
import Numeric
|
||||
|
||||
serialize :: [Element] -> Result
|
||||
serialize elems =
|
||||
serialize :: [Text] -> [Element] -> Result
|
||||
serialize builtins elems =
|
||||
Result
|
||||
{ _resultData =
|
||||
initializeOutput
|
||||
initializeBuiltins
|
||||
++ map toHexText (serialize' elems)
|
||||
++ finalizeOutput
|
||||
++ finalizeBuiltins
|
||||
++ finalizeJump,
|
||||
_resultStart = 0,
|
||||
_resultEnd = length initializeOutput + length elems + length finalizeOutput,
|
||||
_resultEnd = length initializeBuiltins + length elems + length finalizeBuiltins,
|
||||
_resultMain = 0,
|
||||
_resultHints = hints,
|
||||
_resultBuiltins = ["output"]
|
||||
_resultBuiltins = "output" : builtins
|
||||
}
|
||||
where
|
||||
builtinsNum :: Natural
|
||||
builtinsNum = fromIntegral (length builtins)
|
||||
|
||||
hints :: [(Int, Text)]
|
||||
hints = catMaybes $ zipWith mkHint elems [0 ..]
|
||||
|
||||
pcShift :: Int
|
||||
pcShift = length initializeOutput
|
||||
pcShift = length initializeBuiltins
|
||||
|
||||
mkHint :: Element -> Int -> Maybe (Int, Text)
|
||||
mkHint el pc = case el of
|
||||
@ -34,21 +37,30 @@ serialize elems =
|
||||
toHexText :: Natural -> Text
|
||||
toHexText n = "0x" <> fromString (showHex n "")
|
||||
|
||||
initializeOutput :: [Text]
|
||||
initializeOutput =
|
||||
initializeBuiltins :: [Text]
|
||||
initializeBuiltins =
|
||||
-- ap += allBuiltinsNum
|
||||
[ "0x40480017fff7fff",
|
||||
"0x1"
|
||||
toHexText (builtinsNum + 1)
|
||||
]
|
||||
|
||||
finalizeOutput :: [Text]
|
||||
finalizeOutput =
|
||||
finalizeBuiltins :: [Text]
|
||||
finalizeBuiltins =
|
||||
-- [[fp]] = [ap - 1] -- [output_ptr] = [ap - 1]
|
||||
-- [ap] = [fp] + 1; ap++ -- output_ptr
|
||||
[ "0x4002800080007fff",
|
||||
"0x4826800180008000",
|
||||
"0x1"
|
||||
]
|
||||
++
|
||||
-- [ap] = [ap - builtinsNum - 2]; ap++
|
||||
replicate
|
||||
builtinsNum
|
||||
(toHexText (0x48107ffe7fff8000 - shift builtinsNum 32))
|
||||
|
||||
finalizeJump :: [Text]
|
||||
finalizeJump =
|
||||
-- jmp rel 0
|
||||
[ "0x10780017fff7fff",
|
||||
"0x0"
|
||||
]
|
||||
|
27
src/Juvix/Compiler/Casm/Data/Builtins.hs
Normal file
27
src/Juvix/Compiler/Casm/Data/Builtins.hs
Normal file
@ -0,0 +1,27 @@
|
||||
module Juvix.Compiler.Casm.Data.Builtins where
|
||||
|
||||
import Juvix.Extra.Strings qualified as Str
|
||||
import Juvix.Prelude
|
||||
|
||||
-- The order of the builtins must correspond to the "standard" builtin order in
|
||||
-- the Cairo VM implementation. See:
|
||||
-- https://github.com/lambdaclass/cairo-vm/blob/main/vm/src/vm/runners/cairo_runner.rs#L257
|
||||
data Builtin
|
||||
= BuiltinRangeCheck
|
||||
| BuiltinEcOp
|
||||
| BuiltinPoseidon
|
||||
deriving stock (Show, Eq, Generic, Enum, Bounded)
|
||||
|
||||
instance Hashable Builtin
|
||||
|
||||
builtinsNum :: Int
|
||||
builtinsNum = length (allElements @Builtin)
|
||||
|
||||
builtinName :: Builtin -> Text
|
||||
builtinName = \case
|
||||
BuiltinRangeCheck -> Str.cairoRangeCheck
|
||||
BuiltinEcOp -> Str.cairoEcOp
|
||||
BuiltinPoseidon -> Str.cairoPoseidon
|
||||
|
||||
builtinNames :: [Text]
|
||||
builtinNames = map builtinName allElements
|
@ -1,4 +1,8 @@
|
||||
module Juvix.Compiler.Casm.Data.LabelInfoBuilder where
|
||||
module Juvix.Compiler.Casm.Data.LabelInfoBuilder
|
||||
( module Juvix.Compiler.Casm.Data.LabelInfo,
|
||||
module Juvix.Compiler.Casm.Data.LabelInfoBuilder,
|
||||
)
|
||||
where
|
||||
|
||||
import Data.HashMap.Strict qualified as HashMap
|
||||
import Juvix.Compiler.Casm.Data.LabelInfo
|
||||
|
@ -1,11 +1,13 @@
|
||||
module Juvix.Compiler.Casm.Data.Result where
|
||||
|
||||
import Juvix.Compiler.Casm.Data.Builtins
|
||||
import Juvix.Compiler.Casm.Data.LabelInfo
|
||||
import Juvix.Compiler.Casm.Language
|
||||
|
||||
data Result = Result
|
||||
{ _resultLabelInfo :: LabelInfo,
|
||||
_resultCode :: [Instruction]
|
||||
_resultCode :: [Instruction],
|
||||
_resultBuiltins :: [Builtin]
|
||||
}
|
||||
|
||||
makeLenses ''Result
|
||||
|
@ -11,6 +11,7 @@ import Data.HashMap.Strict qualified as HashMap
|
||||
import Data.Vector qualified as Vec
|
||||
import Data.Vector.Mutable qualified as MV
|
||||
import GHC.IO qualified as GHC
|
||||
import Juvix.Compiler.Casm.Data.Builtins
|
||||
import Juvix.Compiler.Casm.Data.InputInfo
|
||||
import Juvix.Compiler.Casm.Data.LabelInfo
|
||||
import Juvix.Compiler.Casm.Error
|
||||
@ -39,7 +40,9 @@ hRunCode hout inputInfo (LabelInfo labelInfo) instrs0 = runST goCode
|
||||
goCode :: ST s FField
|
||||
goCode = do
|
||||
mem <- MV.replicate initialMemSize Nothing
|
||||
go 0 0 0 mem
|
||||
forM_ [0 .. builtinsNum] $ \k ->
|
||||
MV.write mem k (Just (fieldFromInteger cairoFieldSize 0))
|
||||
go 0 (builtinsNum + 1) 0 mem
|
||||
|
||||
go ::
|
||||
Address ->
|
||||
|
@ -5,7 +5,12 @@ import Juvix.Compiler.Casm.Data.Result
|
||||
import Juvix.Compiler.Casm.Language
|
||||
|
||||
fromCairo :: [Cairo.Element] -> Result
|
||||
fromCairo elems0 = Result mempty (go 0 [] elems0)
|
||||
fromCairo elems0 =
|
||||
Result
|
||||
{ _resultLabelInfo = mempty,
|
||||
_resultCode = go 0 [] elems0,
|
||||
_resultBuiltins = mempty
|
||||
}
|
||||
where
|
||||
errorMsg :: Address -> Text -> a
|
||||
errorMsg addr msg = error ("error at address " <> show addr <> ": " <> msg)
|
||||
|
@ -3,6 +3,7 @@ module Juvix.Compiler.Casm.Translation.FromReg where
|
||||
import Data.HashMap.Strict qualified as HashMap
|
||||
import Data.HashSet qualified as HashSet
|
||||
import Data.Text qualified as Text
|
||||
import Juvix.Compiler.Casm.Data.Builtins
|
||||
import Juvix.Compiler.Casm.Data.LabelInfoBuilder
|
||||
import Juvix.Compiler.Casm.Data.Limits
|
||||
import Juvix.Compiler.Casm.Data.Result
|
||||
@ -17,26 +18,56 @@ import Juvix.Compiler.Tree.Evaluator.Builtins qualified as Reg
|
||||
import Juvix.Data.Field
|
||||
|
||||
fromReg :: Reg.InfoTable -> Result
|
||||
fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolId tab) $ do
|
||||
fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolId tab) $ do
|
||||
let startAddr :: Address = 2
|
||||
startSym <- freshSymbol
|
||||
endSym <- freshSymbol
|
||||
let startName :: Text = "__juvix_start"
|
||||
startLab = LabelRef startSym (Just startName)
|
||||
endName :: Text = "__juvix_end"
|
||||
endLab = LabelRef endSym (Just endName)
|
||||
registerLabelName startSym startName
|
||||
registerLabelAddress startSym startAddr
|
||||
let mainSym = fromJust $ tab ^. Reg.infoMainFunction
|
||||
mainInfo = fromJust (HashMap.lookup mainSym (tab ^. Reg.infoFunctions))
|
||||
mainName = mainInfo ^. Reg.functionName
|
||||
mainArgs = getInputArgs (mainInfo ^. Reg.functionArgsNum) (mainInfo ^. Reg.functionArgNames)
|
||||
initialOffset = length mainArgs + 2
|
||||
(blts, binstrs) <- addStdlibBuiltins initialOffset
|
||||
bnum = toOffset builtinsNum
|
||||
callStartInstr = mkCallRel (Lab startLab)
|
||||
initBuiltinsInstr = mkAssignAp (Binop $ BinopValue FieldAdd (MemRef Fp (-2)) (Imm 1))
|
||||
callMainInstr = mkCallRel (Lab $ LabelRef mainSym (Just mainName))
|
||||
jmpEndInstr = mkJumpRel (Val $ Lab endLab)
|
||||
margs = reverse $ map (Hint . HintInput) mainArgs
|
||||
-- [ap] = [[ap - 2 - k] + k]; ap++
|
||||
bltsRet = map (\k -> mkAssignAp (Load $ LoadValue (MemRef Ap (-2 - k)) k)) [0 .. bnum - 1]
|
||||
resRetInstr = mkAssignAp (Val $ Ref $ MemRef Ap (-bnum - 1))
|
||||
pinstrs =
|
||||
callStartInstr
|
||||
: jmpEndInstr
|
||||
: Label startLab
|
||||
: initBuiltinsInstr
|
||||
: margs
|
||||
++ callMainInstr
|
||||
: bltsRet
|
||||
++ [resRetInstr, Return]
|
||||
(blts, binstrs) <- addStdlibBuiltins (length pinstrs)
|
||||
let cinstrs = concatMap (mkFunCall . fst) $ sortOn snd $ HashMap.toList (info ^. Reg.extraInfoFUIDs)
|
||||
endSym <- freshSymbol
|
||||
let endName :: Text = "__juvix_end"
|
||||
endLab = LabelRef endSym (Just endName)
|
||||
(addr, instrs) <- second (concat . reverse) <$> foldM (goFun blts endLab) (initialOffset + length binstrs + length cinstrs, []) (tab ^. Reg.infoFunctions)
|
||||
eassert (addr == length instrs + length cinstrs + length binstrs + initialOffset)
|
||||
(addr, instrs) <- second (concat . reverse) <$> foldM (goFun blts endLab) (length pinstrs + length binstrs + length cinstrs, []) (tab ^. Reg.infoFunctions)
|
||||
eassert (addr == length instrs + length cinstrs + length binstrs + length pinstrs)
|
||||
registerLabelName endSym endName
|
||||
registerLabelAddress endSym addr
|
||||
let callInstr = mkCallRel (Lab $ LabelRef mainSym (Just mainName))
|
||||
jmpInstr = mkJumpRel (Val $ Lab endLab)
|
||||
margs = reverse $ map (Hint . HintInput) mainArgs
|
||||
return $ margs ++ callInstr : jmpInstr : binstrs ++ cinstrs ++ instrs ++ [Label endLab]
|
||||
return $
|
||||
( allElements,
|
||||
pinstrs
|
||||
++ binstrs
|
||||
++ cinstrs
|
||||
++ instrs
|
||||
++ [Label endLab]
|
||||
)
|
||||
where
|
||||
mkResult :: (LabelInfo, ([Builtin], Code)) -> Result
|
||||
mkResult (labi, (blts, code)) = Result labi code blts
|
||||
|
||||
info :: Reg.ExtraInfo
|
||||
info = Reg.computeExtraInfo tab
|
||||
|
||||
@ -72,6 +103,9 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
|
||||
("fp", "__fp__")
|
||||
]
|
||||
|
||||
argsOffset :: Int
|
||||
argsOffset = 3
|
||||
|
||||
goFun :: forall r. (Member LabelInfoBuilder r) => StdlibBuiltins -> LabelRef -> (Address, [[Instruction]]) -> Reg.FunctionInfo -> Sem r (Address, [[Instruction]])
|
||||
goFun blts failLab (addr0, acc) funInfo = do
|
||||
let sym = funInfo ^. Reg.functionSymbol
|
||||
@ -85,10 +119,10 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
|
||||
n = funInfo ^. Reg.functionArgsNum
|
||||
let vars =
|
||||
HashMap.fromList $
|
||||
map (\k -> (Reg.VarRef Reg.VarGroupArgs k Nothing, -3 - k)) [0 .. n - 1]
|
||||
map (\k -> (Reg.VarRef Reg.VarGroupArgs k Nothing, -argsOffset - k)) [0 .. n - 1]
|
||||
instrs <-
|
||||
fmap fst
|
||||
. runCasmBuilder addr1 vars
|
||||
. runCasmBuilder addr1 vars (-argsOffset - n)
|
||||
. runOutputList
|
||||
$ goBlock blts failLab mempty Nothing block
|
||||
return (addr1 + length instrs, (pre ++ instrs) : acc)
|
||||
@ -119,7 +153,7 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
|
||||
Nothing -> do
|
||||
eassert (isJust mout)
|
||||
eassert (HashSet.member (fromJust mout) liveVars0)
|
||||
goCallBlock Nothing liveVars0
|
||||
goCallBlock False Nothing liveVars0
|
||||
where
|
||||
output'' :: Instruction -> Sem r ()
|
||||
output'' i = do
|
||||
@ -131,14 +165,22 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
|
||||
output'' i
|
||||
incAP apOff
|
||||
|
||||
goCallBlock :: Maybe Reg.VarRef -> HashSet Reg.VarRef -> Sem r ()
|
||||
goCallBlock outVar liveVars = do
|
||||
goCallBlock :: Bool -> Maybe Reg.VarRef -> HashSet Reg.VarRef -> Sem r ()
|
||||
goCallBlock updatedBuiltins outVar liveVars = do
|
||||
let liveVars' = toList (maybe liveVars (flip HashSet.delete liveVars) outVar)
|
||||
n = length liveVars'
|
||||
bltOff =
|
||||
if
|
||||
| updatedBuiltins ->
|
||||
-argsOffset - n - fromEnum (isJust outVar)
|
||||
| otherwise ->
|
||||
-argsOffset - n
|
||||
vars =
|
||||
HashMap.fromList $
|
||||
maybe [] (\var -> [(var, -3 - n)]) outVar
|
||||
++ zipWithExact (\var k -> (var, -3 - k)) liveVars' [0 .. n - 1]
|
||||
maybe [] (\var -> [(var, -argsOffset - n - if updatedBuiltins then 0 else 1)]) outVar
|
||||
++ zipWithExact (\var k -> (var, -argsOffset - k)) liveVars' [0 .. n - 1]
|
||||
unless updatedBuiltins $
|
||||
goAssignApBuiltins
|
||||
mapM_ (mkMemRef >=> goAssignAp . Val . Ref) (reverse liveVars')
|
||||
output'' (mkCallRel $ Imm 3)
|
||||
output'' Return
|
||||
@ -148,11 +190,13 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
|
||||
output'' Nop
|
||||
setAP 0
|
||||
setVars vars
|
||||
setBuiltinOffset bltOff
|
||||
|
||||
goLocalBlock :: Int -> HashMap Reg.VarRef Int -> HashSet Reg.VarRef -> Maybe Reg.VarRef -> Reg.Block -> Sem r ()
|
||||
goLocalBlock ap0 vars liveVars mout' block = do
|
||||
goLocalBlock :: Int -> HashMap Reg.VarRef Int -> Int -> HashSet Reg.VarRef -> Maybe Reg.VarRef -> Reg.Block -> Sem r ()
|
||||
goLocalBlock ap0 vars bltOff liveVars mout' block = do
|
||||
setAP ap0
|
||||
setVars vars
|
||||
setBuiltinOffset bltOff
|
||||
goBlock blts failLab liveVars mout' block
|
||||
|
||||
----------------------------------------------------------------------
|
||||
@ -179,6 +223,11 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
|
||||
v <- lookupVar' vr
|
||||
return $ MemRef Fp (toOffset v)
|
||||
|
||||
mkBuiltinRef :: Sem r MemRef
|
||||
mkBuiltinRef = do
|
||||
off <- getBuiltinOffset
|
||||
return $ MemRef Fp (toOffset off)
|
||||
|
||||
mkRValue :: Reg.Value -> Sem r RValue
|
||||
mkRValue = \case
|
||||
Reg.ValConst c -> return $ Val $ Imm $ mkConst c
|
||||
@ -216,6 +265,9 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
|
||||
goAssignApValue :: Reg.Value -> Sem r ()
|
||||
goAssignApValue v = mkRValue v >>= goAssignAp
|
||||
|
||||
goAssignApBuiltins :: Sem r ()
|
||||
goAssignApBuiltins = mkBuiltinRef >>= goAssignAp . Val . Ref
|
||||
|
||||
goValue :: Reg.Value -> Sem r Value
|
||||
goValue = \case
|
||||
Reg.ValConst c -> return $ Imm $ mkConst c
|
||||
@ -414,16 +466,18 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
|
||||
val <- mkMemRef _instrExtendClosureValue
|
||||
goAssignAp (Val $ Ref val)
|
||||
output'' $ mkCallRel $ Lab $ LabelRef (blts ^. stdlibExtendClosure) (Just (blts ^. stdlibExtendClosureName))
|
||||
goCallBlock (Just _instrExtendClosureResult) liveVars
|
||||
goCallBlock False (Just _instrExtendClosureResult) liveVars
|
||||
|
||||
goCall' :: Reg.CallType -> [Reg.Value] -> Sem r ()
|
||||
goCall' ct args = case ct of
|
||||
Reg.CallFun sym -> do
|
||||
goAssignApBuiltins
|
||||
mapM_ goAssignApValue (reverse args)
|
||||
output'' $ mkCallRel $ Lab $ LabelRef sym (Just funName)
|
||||
where
|
||||
funName = quoteName (Reg.lookupFunInfo tab sym ^. Reg.functionName)
|
||||
Reg.CallClosure cl -> do
|
||||
goAssignApBuiltins
|
||||
mapM_ goAssignApValue (reverse args)
|
||||
r <- mkMemRef cl
|
||||
goAssignAp (Val $ Ref r)
|
||||
@ -432,7 +486,7 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
|
||||
goCall :: HashSet Reg.VarRef -> Reg.InstrCall -> Sem r ()
|
||||
goCall liveVars Reg.InstrCall {..} = do
|
||||
goCall' _instrCallType _instrCallArgs
|
||||
goCallBlock (Just _instrCallResult) liveVars
|
||||
goCallBlock True (Just _instrCallResult) liveVars
|
||||
|
||||
-- There is no way to make "proper" tail calls in Cairo, because
|
||||
-- the only way to set the `fp` register is via the `call` instruction.
|
||||
@ -444,6 +498,7 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
|
||||
|
||||
goReturn :: Reg.InstrReturn -> Sem r ()
|
||||
goReturn Reg.InstrReturn {..} = do
|
||||
goAssignApBuiltins
|
||||
goAssignApValue _instrReturnValue
|
||||
output'' Return
|
||||
|
||||
@ -462,14 +517,15 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
|
||||
output'' $ mkJumpIf (Lab labFalse) r
|
||||
ap0 <- getAP
|
||||
vars <- getVars
|
||||
goLocalBlock ap0 vars liveVars _instrBranchOutVar _instrBranchTrue
|
||||
bltOff <- getBuiltinOffset
|
||||
goLocalBlock ap0 vars bltOff liveVars _instrBranchOutVar _instrBranchTrue
|
||||
-- _instrBranchOutVar is Nothing iff the branch returns
|
||||
when (isJust _instrBranchOutVar) $
|
||||
output'' (mkJumpRel (Val $ Lab labEnd))
|
||||
addrFalse <- getPC
|
||||
registerLabelAddress symFalse addrFalse
|
||||
output'' $ Label labFalse
|
||||
goLocalBlock ap0 vars liveVars _instrBranchOutVar _instrBranchFalse
|
||||
goLocalBlock ap0 vars bltOff liveVars _instrBranchOutVar _instrBranchFalse
|
||||
addrEnd <- getPC
|
||||
registerLabelAddress symEnd addrEnd
|
||||
output'' $ Label labEnd
|
||||
@ -501,10 +557,11 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
|
||||
mapM_ output'' jmps'
|
||||
ap0 <- getAP
|
||||
vars <- getVars
|
||||
mapM_ (goCaseBranch ap0 vars symMap labEnd) _instrCaseBranches
|
||||
bltOff <- getBuiltinOffset
|
||||
mapM_ (goCaseBranch ap0 vars bltOff symMap labEnd) _instrCaseBranches
|
||||
mapM_ (goDefaultLabel symMap) defaultTags
|
||||
whenJust _instrCaseDefault $
|
||||
goLocalBlock ap0 vars liveVars _instrCaseOutVar
|
||||
goLocalBlock ap0 vars bltOff liveVars _instrCaseOutVar
|
||||
addrEnd <- getPC
|
||||
registerLabelAddress symEnd addrEnd
|
||||
output'' $ Label labEnd
|
||||
@ -513,14 +570,14 @@ fromReg tab = uncurry Result $ run $ runLabelInfoBuilderWithNextId (Reg.getNextS
|
||||
ctrTags = HashSet.fromList $ map (^. Reg.caseBranchTag) _instrCaseBranches
|
||||
defaultTags = filter (not . flip HashSet.member ctrTags) tags
|
||||
|
||||
goCaseBranch :: Int -> HashMap Reg.VarRef Int -> HashMap Tag Symbol -> LabelRef -> Reg.CaseBranch -> Sem r ()
|
||||
goCaseBranch ap0 vars symMap labEnd Reg.CaseBranch {..} = do
|
||||
goCaseBranch :: Int -> HashMap Reg.VarRef Int -> Int -> HashMap Tag Symbol -> LabelRef -> Reg.CaseBranch -> Sem r ()
|
||||
goCaseBranch ap0 vars bltOff symMap labEnd Reg.CaseBranch {..} = do
|
||||
let sym = fromJust $ HashMap.lookup _caseBranchTag symMap
|
||||
lab = LabelRef sym Nothing
|
||||
addr <- getPC
|
||||
registerLabelAddress sym addr
|
||||
output'' $ Label lab
|
||||
goLocalBlock ap0 vars liveVars _instrCaseOutVar _caseBranchCode
|
||||
goLocalBlock ap0 vars bltOff liveVars _instrCaseOutVar _caseBranchCode
|
||||
-- _instrCaseOutVar is Nothing iff the branch returns
|
||||
when (isJust _instrCaseOutVar) $
|
||||
output'' (mkJumpRel (Val $ Lab labEnd))
|
||||
|
@ -14,27 +14,31 @@ data CasmBuilder :: Effect where
|
||||
LookupVar :: VarRef -> CasmBuilder m (Maybe Int)
|
||||
GetVars :: CasmBuilder m (HashMap VarRef Int)
|
||||
SetVars :: HashMap VarRef Int -> CasmBuilder m ()
|
||||
GetBuiltinOffset :: CasmBuilder m Int
|
||||
SetBuiltinOffset :: Int -> CasmBuilder m ()
|
||||
|
||||
makeSem ''CasmBuilder
|
||||
|
||||
data BuilderState = BuilderState
|
||||
{ _statePC :: Address,
|
||||
_stateAP :: Int,
|
||||
_stateVarMap :: HashMap VarRef Int
|
||||
_stateVarMap :: HashMap VarRef Int,
|
||||
_stateBuiltinOff :: Int
|
||||
}
|
||||
|
||||
makeLenses ''BuilderState
|
||||
|
||||
mkBuilderState :: Address -> HashMap VarRef Int -> BuilderState
|
||||
mkBuilderState addr vars =
|
||||
mkBuilderState :: Address -> HashMap VarRef Int -> Int -> BuilderState
|
||||
mkBuilderState addr vars bltOff =
|
||||
BuilderState
|
||||
{ _statePC = addr,
|
||||
_stateAP = 0,
|
||||
_stateVarMap = vars
|
||||
_stateVarMap = vars,
|
||||
_stateBuiltinOff = bltOff
|
||||
}
|
||||
|
||||
runCasmBuilder :: Address -> HashMap VarRef Int -> Sem (CasmBuilder ': r) a -> Sem r a
|
||||
runCasmBuilder addr vars = fmap snd . runCasmBuilder' (mkBuilderState addr vars)
|
||||
runCasmBuilder :: Address -> HashMap VarRef Int -> Int -> Sem (CasmBuilder ': r) a -> Sem r a
|
||||
runCasmBuilder addr vars bltOff = fmap snd . runCasmBuilder' (mkBuilderState addr vars bltOff)
|
||||
|
||||
runCasmBuilder' :: BuilderState -> Sem (CasmBuilder ': r) a -> Sem r (BuilderState, a)
|
||||
runCasmBuilder' bs = reinterpret (runState bs) interp
|
||||
@ -60,6 +64,10 @@ runCasmBuilder' bs = reinterpret (runState bs) interp
|
||||
gets (^. stateVarMap)
|
||||
SetVars vars -> do
|
||||
modify' (set stateVarMap vars)
|
||||
GetBuiltinOffset -> do
|
||||
gets (^. stateBuiltinOff)
|
||||
SetBuiltinOffset bltOff -> do
|
||||
modify' (set stateBuiltinOff bltOff)
|
||||
|
||||
lookupVar' :: (Member CasmBuilder r) => VarRef -> Sem r Int
|
||||
lookupVar' = lookupVar >=> return . fromJust
|
||||
|
@ -17,6 +17,7 @@ import Juvix.Compiler.Backend.C qualified as C
|
||||
import Juvix.Compiler.Backend.Cairo qualified as Cairo
|
||||
import Juvix.Compiler.Backend.Geb qualified as Geb
|
||||
import Juvix.Compiler.Backend.VampIR.Translation qualified as VampIR
|
||||
import Juvix.Compiler.Casm.Data.Builtins qualified as Casm
|
||||
import Juvix.Compiler.Casm.Data.Result qualified as Casm
|
||||
import Juvix.Compiler.Casm.Translation.FromReg qualified as Casm
|
||||
import Juvix.Compiler.Concrete.Data.Highlight.Input
|
||||
@ -278,7 +279,10 @@ regToCasm :: Reg.InfoTable -> Sem r Casm.Result
|
||||
regToCasm = Reg.toCasm >=> return . Casm.fromReg
|
||||
|
||||
casmToCairo :: Casm.Result -> Sem r Cairo.Result
|
||||
casmToCairo Casm.Result {..} = return $ Cairo.serialize $ Cairo.fromCasm _resultCode
|
||||
casmToCairo Casm.Result {..} =
|
||||
return
|
||||
. Cairo.serialize (map Casm.builtinName _resultBuiltins)
|
||||
$ Cairo.fromCasm _resultCode
|
||||
|
||||
regToCairo :: Reg.InfoTable -> Sem r Cairo.Result
|
||||
regToCairo = regToCasm >=> casmToCairo
|
||||
|
@ -991,3 +991,12 @@ functionsPlaceholder = "functionsLibrary_placeholder"
|
||||
|
||||
theFunctionsLibrary :: (IsString s) => s
|
||||
theFunctionsLibrary = "the_functionsLibrary"
|
||||
|
||||
cairoRangeCheck :: (IsString s) => s
|
||||
cairoRangeCheck = "range_check"
|
||||
|
||||
cairoPoseidon :: (IsString s) => s
|
||||
cairoPoseidon = "poseidon"
|
||||
|
||||
cairoEcOp :: (IsString s) => s
|
||||
cairoEcOp = "ec_op"
|
||||
|
@ -42,4 +42,4 @@ compileAssertionEntry adjustEntry root' bRunVM optLevel mainFile expectedFile st
|
||||
step "Pretty print"
|
||||
writeFileEnsureLn tmpFile (toPlainText $ ppProgram _resultCode)
|
||||
)
|
||||
casmRunAssertion' bRunVM _resultLabelInfo _resultCode Nothing expectedFile step
|
||||
casmRunAssertion' bRunVM _resultLabelInfo _resultCode _resultBuiltins Nothing expectedFile step
|
||||
|
@ -2,6 +2,7 @@ module Casm.Run.Base where
|
||||
|
||||
import Base
|
||||
import Data.Aeson
|
||||
import Juvix.Compiler.Casm.Data.Builtins
|
||||
import Juvix.Compiler.Casm.Data.Result qualified as Casm
|
||||
import Juvix.Compiler.Casm.Error
|
||||
import Juvix.Compiler.Casm.Extra.InputInfo
|
||||
@ -18,14 +19,14 @@ casmRunVM' dirPath outputFile inputFile = do
|
||||
let args = maybe [] (\f -> ["--program_input", toFilePath f]) inputFile
|
||||
R.readProcessCwd (toFilePath dirPath) "run_cairo_vm.sh" (toFilePath outputFile : args) ""
|
||||
|
||||
casmRunVM :: LabelInfo -> Code -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion
|
||||
casmRunVM labi instrs inputFile expectedFile step = do
|
||||
casmRunVM :: LabelInfo -> Code -> [Builtin] -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion
|
||||
casmRunVM labi instrs blts inputFile expectedFile step = do
|
||||
step "Check run_cairo_vm.sh is on path"
|
||||
assertCmdExists $(mkRelFile "run_cairo_vm.sh")
|
||||
withTempDir'
|
||||
( \dirPath -> do
|
||||
step "Serialize to Cairo bytecode"
|
||||
let res = run $ casmToCairo (Casm.Result labi instrs)
|
||||
let res = run $ casmToCairo (Casm.Result labi instrs blts)
|
||||
outputFile = dirPath <//> $(mkRelFile "out.json")
|
||||
encodeFile (toFilePath outputFile) res
|
||||
step "Run Cairo VM"
|
||||
@ -35,8 +36,8 @@ casmRunVM labi instrs inputFile expectedFile step = do
|
||||
assertEqDiffText ("Check: RUN output = " <> toFilePath expectedFile) actualOutput expected
|
||||
)
|
||||
|
||||
casmRunAssertion' :: Bool -> LabelInfo -> Code -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion
|
||||
casmRunAssertion' bRunVM labi instrs inputFile expectedFile step =
|
||||
casmRunAssertion' :: Bool -> LabelInfo -> Code -> [Builtin] -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion
|
||||
casmRunAssertion' bRunVM labi instrs blts inputFile expectedFile step =
|
||||
case validate labi instrs of
|
||||
Left err -> do
|
||||
assertFailure (prettyString err)
|
||||
@ -60,7 +61,7 @@ casmRunAssertion' bRunVM labi instrs inputFile expectedFile step =
|
||||
assertEqDiffText ("Check: RUN output = " <> toFilePath expectedFile) actualOutput expected
|
||||
)
|
||||
when bRunVM $
|
||||
casmRunVM labi instrs inputFile expectedFile step
|
||||
casmRunVM labi instrs blts inputFile expectedFile step
|
||||
|
||||
casmRunAssertion :: Bool -> Path Abs File -> Maybe (Path Abs File) -> Path Abs File -> (String -> IO ()) -> Assertion
|
||||
casmRunAssertion bRunVM mainFile inputFile expectedFile step = do
|
||||
@ -68,7 +69,7 @@ casmRunAssertion bRunVM mainFile inputFile expectedFile step = do
|
||||
r <- parseFile mainFile
|
||||
case r of
|
||||
Left err -> assertFailure (prettyString err)
|
||||
Right (labi, instrs) -> casmRunAssertion' bRunVM labi instrs inputFile expectedFile step
|
||||
Right (labi, instrs) -> casmRunAssertion' bRunVM labi instrs [] inputFile expectedFile step
|
||||
|
||||
casmRunErrorAssertion :: Path Abs File -> (String -> IO ()) -> Assertion
|
||||
casmRunErrorAssertion mainFile step = do
|
||||
|
Loading…
Reference in New Issue
Block a user