Make take and drop primitive instead of splitAt.

This allows us to generalize the type of `take` and simplifies
the implementations.
This commit is contained in:
Rob Dockins 2021-03-29 14:56:29 -07:00
parent a112ed8cf7
commit d3accfb042
4 changed files with 79 additions and 65 deletions

View File

@ -717,8 +717,9 @@ primitive (#) : {front, back, a} (fin front) => [front]a -> [back]a
* Splits a sequence into a pair of sequences.
* 'splitAt z = (x, y)' iff 'x # y = z'.
*/
primitive splitAt : {front, back, a} (fin front) => [front + back]a
-> ([front]a, [back]a)
splitAt : {front, back, a} (fin front) => [front + back]a
-> ([front]a, [back]a)
splitAt xs = (take`{front,back} xs, drop`{front,back} xs)
/**
* Concatenates a list of sequences.
@ -745,18 +746,15 @@ primitive reverse : {n, a} (fin n) => [n]a -> [n]a
*/
primitive transpose : {rows, cols, a} [rows][cols]a -> [cols][rows]a
/**
* Select the first (left-most) 'front' elements of a sequence.
*/
take : {front, back, a} (fin front) => [front + back]a -> [front]a
take (x # _) = x
primitive take : {front, back, a} [front + back]a -> [front]a
/**
* Select all the elements after (to the right of) the 'front' elements of a sequence.
*/
drop : {front, back, a} (fin front) => [front + back]a -> [back]a
drop ((_ : [front] _) # y) = y
primitive drop : {front, back, a} (fin front) => [front + back]a -> [back]a
/**
* Drop the first (left-most) element of a sequence.

View File

@ -38,6 +38,7 @@ module Cryptol.Backend.SeqMap
, concatSeqMap
, splitSeqMap
, memoMap
, delaySeqMap
, zipSeqMap
, mapSeqMap
, mergeSeqMap
@ -130,6 +131,11 @@ dropSeqMap :: Integer -> SeqMap sym a -> SeqMap sym a
dropSeqMap 0 xs = xs
dropSeqMap n xs = IndexSeqMap $ \i -> lookupSeqMap xs (i+n)
delaySeqMap :: Backend sym => sym -> SEval sym (SeqMap sym a) -> SEval sym (SeqMap sym a)
delaySeqMap sym xs =
do xs' <- sDelay sym xs
pure $ IndexSeqMap $ \i -> do m <- xs'; lookupSeqMap m i
-- | Given a sequence map, return a new sequence map that is memoized using
-- a finite map memo table.
memoMap :: Backend sym => sym -> SeqMap sym a -> SEval sym (SeqMap sym a)

View File

@ -1005,48 +1005,50 @@ joinV sym parts each a val = joinSeq sym parts each a =<< fromSeq "joinV" val
{-# INLINE splitAtV #-}
splitAtV ::
{-# INLINE takeV #-}
takeV ::
Backend sym =>
sym ->
Nat' ->
Nat' ->
TValue ->
GenValue sym ->
SEval sym (GenValue sym) ->
SEval sym (GenValue sym)
splitAtV sym front back a val =
takeV sym front back a val =
case front of
Inf -> val
Nat front' ->
case back of
Nat back' | isTBit a ->
do w <- delayWordValue sym front' (fst <$> (splitWordVal sym front' back' =<< (fromWordVal "takeV" <$> val)))
pure (VWord front' w)
Inf | isTBit a ->
do w <- delayWordValue sym front' (largeBitsVal front' . fmap fromVBit <$> (fromSeq "takeV" =<< val))
pure (VWord front' w)
_ ->
do xs <- delaySeqMap sym (fromSeq "takeV" =<< val)
pure (VSeq front' xs)
{-# INLINE dropV #-}
dropV ::
Backend sym =>
sym ->
Integer ->
Nat' ->
TValue ->
SEval sym (GenValue sym) ->
SEval sym (GenValue sym)
dropV sym front back a val =
case back of
Nat back' | isTBit a ->
do w <- delayWordValue sym back' (snd <$> (splitWordVal sym front back' =<< (fromWordVal "dropV" <$> val)))
pure (VWord back' w)
Nat rightWidth | aBit -> do
ws <- sDelay sym (splitWordVal sym leftWidth rightWidth (fromWordVal "splitAtV" val))
return $ VTuple
[ VWord leftWidth . fst <$> ws
, VWord rightWidth . snd <$> ws
]
Inf | aBit -> do
vs <- sDelay sym (fromSeq "splitAtV" val)
ls <- sDelay sym (fmap fromVBit . fst . splitSeqMap leftWidth <$> vs)
rs <- sDelay sym (snd . splitSeqMap leftWidth <$> vs)
return $ VTuple [ VWord leftWidth . largeBitsVal leftWidth <$> ls
, VStream <$> rs
]
_ -> do
vs <- sDelay sym (fromSeq "splitAtV" val)
ls <- sDelay sym (fst . splitSeqMap leftWidth <$> vs)
rs <- sDelay sym (snd . splitSeqMap leftWidth <$> vs)
return $ VTuple [ VSeq leftWidth <$> ls
, mkSeq back a <$> rs
]
where
aBit = isTBit a
leftWidth = case front of
Nat n -> n
_ -> evalPanic "splitAtV" ["invalid `front` len"]
_ ->
do xs <- delaySeqMap sym (dropSeqMap front <$> (fromSeq "dropV" =<< val))
pure $ mkSeq back a xs
{-# INLINE ecSplitV #-}
@ -2055,12 +2057,19 @@ genericPrimTable sym getEOpts =
, ("split" , {-# SCC "Prelude::split" #-}
ecSplitV sym)
, ("splitAt" , {-# SCC "Prelude::splitAt" #-}
, ("take" , {-# SCC "Preldue::take" #-}
PNumPoly \front ->
PNumPoly \back ->
PTyPoly \a ->
PStrict \x ->
PPrim $ splitAtV sym front back a x)
PNumPoly \back ->
PTyPoly \a ->
PFun \xs ->
PPrim $ takeV sym front back a xs)
, ("drop" , {-# SCC "Preldue::take" #-}
PFinPoly \front ->
PNumPoly \back ->
PTyPoly \a ->
PFun \xs ->
PPrim $ dropV sym front back a xs)
, ("reverse" , {-# SCC "Prelude::reverse" #-}
PFinPoly \_a ->

View File

@ -75,6 +75,7 @@ module Cryptol.Eval.Value
, concatSeqMap
, splitSeqMap
, memoMap
, delaySeqMap
, zipSeqMap
, mapSeqMap
@ -188,7 +189,7 @@ forceValue v = case v of
instance Backend sym => Show (GenValue sym) where
instance Show (GenValue sym) where
show v = case v of
VRecord fs -> "record:" ++ show (displayOrder fs)
VTuple xs -> "tuple:" ++ show (length xs)
@ -394,47 +395,47 @@ mkSeq len elty vals = case len of
fromVBit :: GenValue sym -> SBit sym
fromVBit val = case val of
VBit b -> b
_ -> evalPanic "fromVBit" ["not a Bit"]
_ -> evalPanic "fromVBit" ["not a Bit", show val]
-- | Extract an integer value.
fromVInteger :: GenValue sym -> SInteger sym
fromVInteger val = case val of
VInteger i -> i
_ -> evalPanic "fromVInteger" ["not an Integer"]
_ -> evalPanic "fromVInteger" ["not an Integer", show val]
-- | Extract a rational value.
fromVRational :: GenValue sym -> SRational sym
fromVRational val = case val of
VRational q -> q
_ -> evalPanic "fromVRational" ["not a Rational"]
_ -> evalPanic "fromVRational" ["not a Rational", show val]
-- | Extract a finite sequence value.
fromVSeq :: GenValue sym -> SeqMap sym (GenValue sym)
fromVSeq val = case val of
VSeq _ vs -> vs
_ -> evalPanic "fromVSeq" ["not a sequence"]
_ -> evalPanic "fromVSeq" ["not a sequence", show val]
-- | Extract a sequence.
fromSeq :: Backend sym => String -> GenValue sym -> SEval sym (SeqMap sym (GenValue sym))
fromSeq msg val = case val of
VSeq _ vs -> return vs
VStream vs -> return vs
_ -> evalPanic "fromSeq" ["not a sequence", msg]
_ -> evalPanic "fromSeq" ["not a sequence", msg, show val]
fromWordVal :: Backend sym => String -> GenValue sym -> WordValue sym
fromWordVal _msg (VWord _ wval) = wval
fromWordVal msg _ = evalPanic "fromWordVal" ["not a word value", msg]
fromWordVal msg val = evalPanic "fromWordVal" ["not a word value", msg, show val]
asIndex :: Backend sym =>
sym -> String -> TValue -> GenValue sym -> Either (SInteger sym) (WordValue sym)
asIndex _sym _msg TVInteger (VInteger i) = Left i
asIndex _sym _msg _ (VWord _ wval) = Right wval
asIndex _sym msg _ _ = evalPanic "asIndex" ["not an index value", msg]
asIndex _sym msg _ val = evalPanic "asIndex" ["not an index value", msg, show val]
-- | Extract a packed word.
fromVWord :: Backend sym => sym -> String -> GenValue sym -> SEval sym (SWord sym)
fromVWord sym _msg (VWord _ wval) = asWordVal sym wval
fromVWord _ msg _ = evalPanic "fromVWord" ["not a word", msg]
fromVWord _ msg val = evalPanic "fromVWord" ["not a word", msg, show val]
vWordLen :: Backend sym => GenValue sym -> Maybe Integer
vWordLen val = case val of
@ -456,46 +457,46 @@ fromVFun :: Backend sym => sym -> GenValue sym -> (SEval sym (GenValue sym) -> S
fromVFun sym val = case val of
VFun fnstk f ->
\x -> sModifyCallStack sym (\stk -> combineCallStacks stk fnstk) (f x)
_ -> evalPanic "fromVFun" ["not a function"]
_ -> evalPanic "fromVFun" ["not a function", show val]
-- | Extract a polymorphic function from a value.
fromVPoly :: Backend sym => sym -> GenValue sym -> (TValue -> SEval sym (GenValue sym))
fromVPoly sym val = case val of
VPoly fnstk f ->
\x -> sModifyCallStack sym (\stk -> combineCallStacks stk fnstk) (f x)
_ -> evalPanic "fromVPoly" ["not a polymorphic value"]
_ -> evalPanic "fromVPoly" ["not a polymorphic value", show val]
-- | Extract a polymorphic function from a value.
fromVNumPoly :: Backend sym => sym -> GenValue sym -> (Nat' -> SEval sym (GenValue sym))
fromVNumPoly sym val = case val of
VNumPoly fnstk f ->
\x -> sModifyCallStack sym (\stk -> combineCallStacks stk fnstk) (f x)
_ -> evalPanic "fromVNumPoly" ["not a polymorphic value"]
_ -> evalPanic "fromVNumPoly" ["not a polymorphic value", show val]
-- | Extract a tuple from a value.
fromVTuple :: GenValue sym -> [SEval sym (GenValue sym)]
fromVTuple val = case val of
VTuple vs -> vs
_ -> evalPanic "fromVTuple" ["not a tuple"]
_ -> evalPanic "fromVTuple" ["not a tuple", show val]
-- | Extract a record from a value.
fromVRecord :: GenValue sym -> RecordMap Ident (SEval sym (GenValue sym))
fromVRecord val = case val of
VRecord fs -> fs
_ -> evalPanic "fromVRecord" ["not a record"]
_ -> evalPanic "fromVRecord" ["not a record", show val]
fromVFloat :: GenValue sym -> SFloat sym
fromVFloat val =
case val of
VFloat x -> x
_ -> evalPanic "fromVFloat" ["not a Float"]
_ -> evalPanic "fromVFloat" ["not a Float", show val]
-- | Lookup a field in a record.
lookupRecord :: Ident -> GenValue sym -> SEval sym (GenValue sym)
lookupRecord f val =
case lookupField f (fromVRecord val) of
Just x -> x
Nothing -> evalPanic "lookupRecord" ["malformed record"]
Nothing -> evalPanic "lookupRecord" ["malformed record", show val]
-- Merge and if/then/else
@ -532,7 +533,7 @@ mergeValue sym c v1 v2 =
(VRecord fs1 , VRecord fs2 ) ->
do let res = zipRecords (\_lbl -> mergeValue' sym c) fs1 fs2
case res of
Left f -> panic "Cryptol.Eval.Generic" [ "mergeValue: incompatible record values", show f ]
Left f -> panic "Cryptol.Eval.Value" [ "mergeValue: incompatible record values", show f ]
Right r -> pure (VRecord r)
(VTuple vs1 , VTuple vs2 ) | length vs1 == length vs2 ->
pure $ VTuple $ zipWith (mergeValue' sym c) vs1 vs2
@ -545,8 +546,8 @@ mergeValue sym c v1 v2 =
(VStream vs1 , VStream vs2 ) -> VStream <$> memoMap sym (mergeSeqMapVal sym c vs1 vs2)
(f1@VFun{} , f2@VFun{} ) -> lam sym $ \x -> mergeValue' sym c (fromVFun sym f1 x) (fromVFun sym f2 x)
(f1@VPoly{} , f2@VPoly{} ) -> tlam sym $ \x -> mergeValue' sym c (fromVPoly sym f1 x) (fromVPoly sym f2 x)
(_ , _ ) -> panic "Cryptol.Eval.Generic"
[ "mergeValue: incompatible values" ]
(_ , _ ) -> panic "Cryptol.Eval.Value"
[ "mergeValue: incompatible values", show v1, show v2 ]
{-# INLINE mergeSeqMapVal #-}
mergeSeqMapVal :: Backend sym =>