Just cosmeic tweaks

This commit is contained in:
Iavor Diatchki 2022-09-22 16:17:22 +03:00
parent adfb760cad
commit cb43685abc

View File

@ -27,7 +27,7 @@ import Cryptol.TypeCheck.FFI.FFIType
#ifdef FFI_ENABLED
import Control.Exception
import Control.Exception(bracket_)
import Data.Either
import Data.Foldable
import Data.IORef
@ -159,29 +159,45 @@ foreignPrim name FFIFunType {..} impl tenv = buildFun ffiArgTypes []
--
-- NOTE: the result must be used only in the callback since it may have a
-- limited lifetime (e.g. pointer returned by alloca).
marshalArg :: FFIType -> GenValue Concrete ->
([SomeFFIArg] -> Eval a) -> Eval a
marshalArg FFIBool val f = f [SomeFFIArg @Word8 $ fromBool $ fromVBit val]
marshalArg (FFIBasic (FFIBasicVal t)) val f = getMarshalBasicValArg t \m -> do
arg <- m val
f [SomeFFIArg arg]
marshalArg (FFIBasic (FFIBasicRef t)) val f = getMarshalBasicRefArg t \m ->
marshalArg ::
FFIType ->
GenValue Concrete ->
([SomeFFIArg] -> Eval a) ->
Eval a
marshalArg FFIBool val f = f [SomeFFIArg @Word8 (fromBool (fromVBit val))]
marshalArg (FFIBasic (FFIBasicVal t)) val f =
getMarshalBasicValArg t \doExport ->
do arg <- doExport val
f [SomeFFIArg arg]
marshalArg (FFIBasic (FFIBasicRef t)) val f =
getMarshalBasicRefArg t \doExport ->
-- Since we need to do Eval actions in an IO callback, we need to manually
-- unwrap and wrap the Eval datatype
Eval \stk ->
m val \x ->
with x \ptr ->
runEval stk $ f [SomeFFIArg ptr]
doExport val \arg ->
with arg \ptr ->
runEval stk (f [SomeFFIArg ptr])
marshalArg (FFIArray (map evalFinType -> sizes) bt) val f =
case bt of
FFIBasicVal t -> getMarshalBasicValArg t \m ->
Eval \stk ->
marshalArrayArg stk \v g ->
runEval stk (m v) >>= g
FFIBasicVal t ->
getMarshalBasicValArg t \doExport ->
-- Since we need to do Eval actions in an IO callback,
-- we need to manually unwrap and wrap the Eval datatype
Eval \stk ->
marshalArrayArg stk \v k ->
k =<< runEval stk (doExport v)
FFIBasicRef t -> Eval \stk ->
getMarshalBasicRefArg t $ marshalArrayArg stk
where marshalArrayArg stk m =
allocaArray (fromInteger $ product sizes) \ptr -> do
getMarshalBasicRefArg t \doExport ->
marshalArrayArg stk doExport
where marshalArrayArg stk doExport =
allocaArray (fromInteger (product sizes)) \ptr -> do
-- Traverse the nested sequences and write the elements to the
-- array in order.
-- ns is the dimensions of the values we are currently
@ -191,69 +207,106 @@ foreignPrim name FFIFunType {..} impl tenv = buildFun ffiArgTypes []
-- that we push onto when we start processing a nested sequence
-- and pop off when we finish processing the current ones.
-- i is the index into the array.
let write (n:ns) (v:vs) nvss !i = do
vs' <- traverse (runEval stk) $
enumerateSeqMap n $ fromVSeq v
write ns vs' ((n, vs):nvss) i
write [] (v:vs) nvss !i = m v \x -> do
pokeElemOff ptr i x
write [] vs nvss (i + 1)
let
-- write next element of multi-dimensional array
write (n:ns) (v:vs) nvss !i =
do vs' <- traverse (runEval stk)
(enumerateSeqMap n (fromVSeq v))
write ns vs' ((n, vs):nvss) i
-- write next element in flat array
write [] (v:vs) nvss !i =
doExport v \rep ->
do pokeElemOff ptr i rep
write [] vs nvss (i + 1)
-- finished with flat array, do next element of multi-d array
write ns [] ((n, vs):nvss) !i = write (n:ns) vs nvss i
-- done
write _ _ [] _ = pure ()
write sizes [val] [] 0
runEval stk $ f [SomeFFIArg ptr]
marshalArg (FFITuple types) val f = do
vals <- sequence $ fromVTuple val
marshalArgs (zip types vals) f
marshalArg (FFIRecord typeMap) val f = do
vals <- traverse (`lookupRecord` val) $ displayOrder typeMap
marshalArgs (zip (displayElements typeMap) vals) f
marshalArg (FFITuple types) val f =
do vals <- sequence (fromVTuple val)
marshalArgs (types `zip` vals) f
marshalArg (FFIRecord typeMap) val f =
do vals <- traverse (`lookupRecord` val) (displayOrder typeMap)
marshalArgs (displayElements typeMap `zip` vals) f
-- Call marshalArg on a bunch of arguments and collect the results together
-- (in the order of the arguments).
marshalArgs :: [(FFIType, GenValue Concrete)] ->
([SomeFFIArg] -> Eval a) -> Eval a
marshalArgs ::
[(FFIType, GenValue Concrete)] ->
([SomeFFIArg] -> Eval a) ->
Eval a
marshalArgs typesAndVals f = go typesAndVals []
where go [] args = f args
go ((t, v):tvs) prevArgs = marshalArg t v \currArgs ->
go tvs (prevArgs ++ currArgs)
where
go [] args = f (concat (reverse args))
go ((t, v):tvs) prevArgs =
marshalArg t v \currArgs ->
go tvs (currArgs : prevArgs)
-- Given an FFIType and a GetRet, obtain a return value and convert it to a
-- Cryptol value. The return value is obtained differently depending on the
-- FFIType.
marshalRet :: FFIType -> GetRet -> Eval (GenValue Concrete)
marshalRet FFIBool gr = VBit . toBool <$> io (getRetAsValue gr @Word8)
marshalRet FFIBool gr =
do rep <- io (getRetAsValue gr @Word8)
pure (VBit (toBool rep))
marshalRet (FFIBasic (FFIBasicVal t)) gr =
getMarshalBasicValRet t (io (getRetAsValue gr) >>=)
getMarshalBasicValRet t \doImport ->
do rep <- io (getRetAsValue gr)
doImport rep
marshalRet (FFIBasic (FFIBasicRef t)) gr =
getBasicRefRet t \BasicRefRet {..} ->
Eval \stk ->
alloca \ptr ->
bracket_ (initBasicRefRet ptr) (clearBasicRefRet ptr) do
getRetAsOutArgs gr [SomeFFIArg ptr]
peek ptr >>= runEval stk . marshalBasicRefRet
marshalRet (FFIArray (map evalFinType -> sizes) bt) gr = Eval \stk -> do
let totalSize = fromInteger $ product sizes
getBasicRefRet t \how ->
Eval \stk ->
alloca \ptr ->
bracket_ (initBasicRefRet how ptr) (clearBasicRefRet how ptr)
do getRetAsOutArgs gr [SomeFFIArg ptr]
rep <- peek ptr
runEval stk (marshalBasicRefRet how rep)
marshalRet (FFIArray (map evalFinType -> sizes) bt) gr =
Eval \stk -> do
let totalSize = fromInteger (product sizes)
getResult marshal ptr = do
getRetAsOutArgs gr [SomeFFIArg ptr]
let build (n:ns) !i = do
-- We need to be careful to actually run this here and not just
-- stick the IO action into the sequence with io, or else we
-- will read from the array after it is deallocated.
vs <- for [0 .. fromInteger n - 1] \j ->
build ns (i * fromInteger n + j)
pure $ VSeq n $ finiteSeqMap Concrete $ map pure vs
pure (VSeq n (finiteSeqMap Concrete (map pure vs)))
build [] !i = peekElemOff ptr i >>= runEval stk . marshal
build sizes 0
case bt of
FFIBasicVal t -> getMarshalBasicValRet t \m ->
allocaArray totalSize $ getResult m
FFIBasicRef t -> getBasicRefRet t \BasicRefRet {..} ->
allocaArray totalSize \ptr -> do
let forEach f = for_ [0 .. totalSize - 1] $ f . advancePtr ptr
bracket_ (forEach initBasicRefRet) (forEach clearBasicRefRet) $
getResult marshalBasicRefRet ptr
FFIBasicVal t ->
getMarshalBasicValRet t \doImport ->
allocaArray totalSize (getResult doImport)
FFIBasicRef t ->
getBasicRefRet t \how ->
allocaArray totalSize \ptr ->
do let forEach f = for_ [0 .. totalSize - 1] (f . advancePtr ptr)
bracket_ (forEach (initBasicRefRet how))
(forEach (clearBasicRefRet how))
(getResult (marshalBasicRefRet how) ptr)
marshalRet (FFITuple types) gr = VTuple <$> marshalMultiRet types gr
marshalRet (FFIRecord typeMap) gr =
VRecord . recordFromFields . zip (displayOrder typeMap) <$>
marshalMultiRet (displayElements typeMap) gr
@ -318,10 +371,17 @@ getRetFromAsOutArgs f = GetRet
-- that marshals values to the 'FFIArg' type corresponding to the
-- 'FFIBasicValType'. The callback must be able to handle marshalling functions
-- that marshal to any 'FFIArg' type.
getMarshalBasicValArg :: FFIBasicValType ->
(forall a. FFIArg a => (GenValue Concrete -> Eval a) -> b) -> b
getMarshalBasicValArg ::
FFIBasicValType ->
(forall rep.
FFIArg rep =>
(GenValue Concrete -> Eval rep) ->
result) ->
result
getMarshalBasicValArg (FFIWord _ s) f = withWordType s \(_ :: p t) ->
f @t $ fmap (fromInteger . bvVal) . fromVWord Concrete "getMarshalBasicValArg"
getMarshalBasicValArg (FFIFloat _ _ s) f =
case s of
-- LibBF can only convert to 'Double' directly, so we do that first then
@ -329,7 +389,8 @@ getMarshalBasicValArg (FFIFloat _ _ s) f =
-- the original data was 32-bit anyways.
FFIFloat32 -> f $ pure . CFloat . double2Float . toDouble
FFIFloat64 -> f $ pure . CDouble . toDouble
where toDouble = fst . bfToDouble NearEven . bfValue . fromVFloat
where
toDouble = fst . bfToDouble NearEven . bfValue . fromVFloat
-- | Given a 'FFIBasicValType', call the callback with an unmarshalling function
-- from the 'FFIRet' type corresponding to the 'FFIBasicValType' to Cryptol
@ -358,8 +419,11 @@ withWordType FFIWord64 f = f $ Proxy @Word64
-- that takes a Cryptol value and calls its callback with the 'Storable' type
-- corresponding to the 'FFIBasicRefType'.
getMarshalBasicRefArg :: FFIBasicRefType ->
(forall a. Storable a =>
(GenValue Concrete -> (a -> IO b) -> IO b) -> c) -> c
(forall rep.
Storable rep =>
(GenValue Concrete -> (rep -> IO val) -> IO val) ->
result) ->
result
getMarshalBasicRefArg (FFIInteger _) f = f \val g ->
withInInteger' (fromVInteger val) g
getMarshalBasicRefArg FFIRational f = f \val g -> do