Make take and drop primitive instead of splitAt.

This allows us to generalize the type of `take` and simplifies
the implementations.
This commit is contained in:
Rob Dockins 2021-03-29 14:56:29 -07:00
parent a112ed8cf7
commit d3accfb042
4 changed files with 79 additions and 65 deletions

View File

@ -717,8 +717,9 @@ primitive (#) : {front, back, a} (fin front) => [front]a -> [back]a
* Splits a sequence into a pair of sequences. * Splits a sequence into a pair of sequences.
* 'splitAt z = (x, y)' iff 'x # y = z'. * 'splitAt z = (x, y)' iff 'x # y = z'.
*/ */
primitive splitAt : {front, back, a} (fin front) => [front + back]a splitAt : {front, back, a} (fin front) => [front + back]a
-> ([front]a, [back]a) -> ([front]a, [back]a)
splitAt xs = (take`{front,back} xs, drop`{front,back} xs)
/** /**
* Concatenates a list of sequences. * Concatenates a list of sequences.
@ -745,18 +746,15 @@ primitive reverse : {n, a} (fin n) => [n]a -> [n]a
*/ */
primitive transpose : {rows, cols, a} [rows][cols]a -> [cols][rows]a primitive transpose : {rows, cols, a} [rows][cols]a -> [cols][rows]a
/** /**
* Select the first (left-most) 'front' elements of a sequence. * Select the first (left-most) 'front' elements of a sequence.
*/ */
take : {front, back, a} (fin front) => [front + back]a -> [front]a primitive take : {front, back, a} [front + back]a -> [front]a
take (x # _) = x
/** /**
* Select all the elements after (to the right of) the 'front' elements of a sequence. * Select all the elements after (to the right of) the 'front' elements of a sequence.
*/ */
drop : {front, back, a} (fin front) => [front + back]a -> [back]a primitive drop : {front, back, a} (fin front) => [front + back]a -> [back]a
drop ((_ : [front] _) # y) = y
/** /**
* Drop the first (left-most) element of a sequence. * Drop the first (left-most) element of a sequence.

View File

@ -38,6 +38,7 @@ module Cryptol.Backend.SeqMap
, concatSeqMap , concatSeqMap
, splitSeqMap , splitSeqMap
, memoMap , memoMap
, delaySeqMap
, zipSeqMap , zipSeqMap
, mapSeqMap , mapSeqMap
, mergeSeqMap , mergeSeqMap
@ -130,6 +131,11 @@ dropSeqMap :: Integer -> SeqMap sym a -> SeqMap sym a
dropSeqMap 0 xs = xs dropSeqMap 0 xs = xs
dropSeqMap n xs = IndexSeqMap $ \i -> lookupSeqMap xs (i+n) dropSeqMap n xs = IndexSeqMap $ \i -> lookupSeqMap xs (i+n)
delaySeqMap :: Backend sym => sym -> SEval sym (SeqMap sym a) -> SEval sym (SeqMap sym a)
delaySeqMap sym xs =
do xs' <- sDelay sym xs
pure $ IndexSeqMap $ \i -> do m <- xs'; lookupSeqMap m i
-- | Given a sequence map, return a new sequence map that is memoized using -- | Given a sequence map, return a new sequence map that is memoized using
-- a finite map memo table. -- a finite map memo table.
memoMap :: Backend sym => sym -> SeqMap sym a -> SEval sym (SeqMap sym a) memoMap :: Backend sym => sym -> SeqMap sym a -> SEval sym (SeqMap sym a)

View File

@ -1005,48 +1005,50 @@ joinV sym parts each a val = joinSeq sym parts each a =<< fromSeq "joinV" val
{-# INLINE splitAtV #-} {-# INLINE takeV #-}
splitAtV :: takeV ::
Backend sym => Backend sym =>
sym -> sym ->
Nat' -> Nat' ->
Nat' -> Nat' ->
TValue -> TValue ->
GenValue sym -> SEval sym (GenValue sym) ->
SEval sym (GenValue sym) SEval sym (GenValue sym)
splitAtV sym front back a val = takeV sym front back a val =
case front of
Inf -> val
Nat front' ->
case back of
Nat back' | isTBit a ->
do w <- delayWordValue sym front' (fst <$> (splitWordVal sym front' back' =<< (fromWordVal "takeV" <$> val)))
pure (VWord front' w)
Inf | isTBit a ->
do w <- delayWordValue sym front' (largeBitsVal front' . fmap fromVBit <$> (fromSeq "takeV" =<< val))
pure (VWord front' w)
_ ->
do xs <- delaySeqMap sym (fromSeq "takeV" =<< val)
pure (VSeq front' xs)
{-# INLINE dropV #-}
dropV ::
Backend sym =>
sym ->
Integer ->
Nat' ->
TValue ->
SEval sym (GenValue sym) ->
SEval sym (GenValue sym)
dropV sym front back a val =
case back of case back of
Nat back' | isTBit a ->
do w <- delayWordValue sym back' (snd <$> (splitWordVal sym front back' =<< (fromWordVal "dropV" <$> val)))
pure (VWord back' w)
Nat rightWidth | aBit -> do _ ->
ws <- sDelay sym (splitWordVal sym leftWidth rightWidth (fromWordVal "splitAtV" val)) do xs <- delaySeqMap sym (dropSeqMap front <$> (fromSeq "dropV" =<< val))
return $ VTuple pure $ mkSeq back a xs
[ VWord leftWidth . fst <$> ws
, VWord rightWidth . snd <$> ws
]
Inf | aBit -> do
vs <- sDelay sym (fromSeq "splitAtV" val)
ls <- sDelay sym (fmap fromVBit . fst . splitSeqMap leftWidth <$> vs)
rs <- sDelay sym (snd . splitSeqMap leftWidth <$> vs)
return $ VTuple [ VWord leftWidth . largeBitsVal leftWidth <$> ls
, VStream <$> rs
]
_ -> do
vs <- sDelay sym (fromSeq "splitAtV" val)
ls <- sDelay sym (fst . splitSeqMap leftWidth <$> vs)
rs <- sDelay sym (snd . splitSeqMap leftWidth <$> vs)
return $ VTuple [ VSeq leftWidth <$> ls
, mkSeq back a <$> rs
]
where
aBit = isTBit a
leftWidth = case front of
Nat n -> n
_ -> evalPanic "splitAtV" ["invalid `front` len"]
{-# INLINE ecSplitV #-} {-# INLINE ecSplitV #-}
@ -2055,12 +2057,19 @@ genericPrimTable sym getEOpts =
, ("split" , {-# SCC "Prelude::split" #-} , ("split" , {-# SCC "Prelude::split" #-}
ecSplitV sym) ecSplitV sym)
, ("splitAt" , {-# SCC "Prelude::splitAt" #-} , ("take" , {-# SCC "Preldue::take" #-}
PNumPoly \front -> PNumPoly \front ->
PNumPoly \back -> PNumPoly \back ->
PTyPoly \a -> PTyPoly \a ->
PStrict \x -> PFun \xs ->
PPrim $ splitAtV sym front back a x) PPrim $ takeV sym front back a xs)
, ("drop" , {-# SCC "Preldue::take" #-}
PFinPoly \front ->
PNumPoly \back ->
PTyPoly \a ->
PFun \xs ->
PPrim $ dropV sym front back a xs)
, ("reverse" , {-# SCC "Prelude::reverse" #-} , ("reverse" , {-# SCC "Prelude::reverse" #-}
PFinPoly \_a -> PFinPoly \_a ->

View File

@ -75,6 +75,7 @@ module Cryptol.Eval.Value
, concatSeqMap , concatSeqMap
, splitSeqMap , splitSeqMap
, memoMap , memoMap
, delaySeqMap
, zipSeqMap , zipSeqMap
, mapSeqMap , mapSeqMap
@ -188,7 +189,7 @@ forceValue v = case v of
instance Backend sym => Show (GenValue sym) where instance Show (GenValue sym) where
show v = case v of show v = case v of
VRecord fs -> "record:" ++ show (displayOrder fs) VRecord fs -> "record:" ++ show (displayOrder fs)
VTuple xs -> "tuple:" ++ show (length xs) VTuple xs -> "tuple:" ++ show (length xs)
@ -394,47 +395,47 @@ mkSeq len elty vals = case len of
fromVBit :: GenValue sym -> SBit sym fromVBit :: GenValue sym -> SBit sym
fromVBit val = case val of fromVBit val = case val of
VBit b -> b VBit b -> b
_ -> evalPanic "fromVBit" ["not a Bit"] _ -> evalPanic "fromVBit" ["not a Bit", show val]
-- | Extract an integer value. -- | Extract an integer value.
fromVInteger :: GenValue sym -> SInteger sym fromVInteger :: GenValue sym -> SInteger sym
fromVInteger val = case val of fromVInteger val = case val of
VInteger i -> i VInteger i -> i
_ -> evalPanic "fromVInteger" ["not an Integer"] _ -> evalPanic "fromVInteger" ["not an Integer", show val]
-- | Extract a rational value. -- | Extract a rational value.
fromVRational :: GenValue sym -> SRational sym fromVRational :: GenValue sym -> SRational sym
fromVRational val = case val of fromVRational val = case val of
VRational q -> q VRational q -> q
_ -> evalPanic "fromVRational" ["not a Rational"] _ -> evalPanic "fromVRational" ["not a Rational", show val]
-- | Extract a finite sequence value. -- | Extract a finite sequence value.
fromVSeq :: GenValue sym -> SeqMap sym (GenValue sym) fromVSeq :: GenValue sym -> SeqMap sym (GenValue sym)
fromVSeq val = case val of fromVSeq val = case val of
VSeq _ vs -> vs VSeq _ vs -> vs
_ -> evalPanic "fromVSeq" ["not a sequence"] _ -> evalPanic "fromVSeq" ["not a sequence", show val]
-- | Extract a sequence. -- | Extract a sequence.
fromSeq :: Backend sym => String -> GenValue sym -> SEval sym (SeqMap sym (GenValue sym)) fromSeq :: Backend sym => String -> GenValue sym -> SEval sym (SeqMap sym (GenValue sym))
fromSeq msg val = case val of fromSeq msg val = case val of
VSeq _ vs -> return vs VSeq _ vs -> return vs
VStream vs -> return vs VStream vs -> return vs
_ -> evalPanic "fromSeq" ["not a sequence", msg] _ -> evalPanic "fromSeq" ["not a sequence", msg, show val]
fromWordVal :: Backend sym => String -> GenValue sym -> WordValue sym fromWordVal :: Backend sym => String -> GenValue sym -> WordValue sym
fromWordVal _msg (VWord _ wval) = wval fromWordVal _msg (VWord _ wval) = wval
fromWordVal msg _ = evalPanic "fromWordVal" ["not a word value", msg] fromWordVal msg val = evalPanic "fromWordVal" ["not a word value", msg, show val]
asIndex :: Backend sym => asIndex :: Backend sym =>
sym -> String -> TValue -> GenValue sym -> Either (SInteger sym) (WordValue sym) sym -> String -> TValue -> GenValue sym -> Either (SInteger sym) (WordValue sym)
asIndex _sym _msg TVInteger (VInteger i) = Left i asIndex _sym _msg TVInteger (VInteger i) = Left i
asIndex _sym _msg _ (VWord _ wval) = Right wval asIndex _sym _msg _ (VWord _ wval) = Right wval
asIndex _sym msg _ _ = evalPanic "asIndex" ["not an index value", msg] asIndex _sym msg _ val = evalPanic "asIndex" ["not an index value", msg, show val]
-- | Extract a packed word. -- | Extract a packed word.
fromVWord :: Backend sym => sym -> String -> GenValue sym -> SEval sym (SWord sym) fromVWord :: Backend sym => sym -> String -> GenValue sym -> SEval sym (SWord sym)
fromVWord sym _msg (VWord _ wval) = asWordVal sym wval fromVWord sym _msg (VWord _ wval) = asWordVal sym wval
fromVWord _ msg _ = evalPanic "fromVWord" ["not a word", msg] fromVWord _ msg val = evalPanic "fromVWord" ["not a word", msg, show val]
vWordLen :: Backend sym => GenValue sym -> Maybe Integer vWordLen :: Backend sym => GenValue sym -> Maybe Integer
vWordLen val = case val of vWordLen val = case val of
@ -456,46 +457,46 @@ fromVFun :: Backend sym => sym -> GenValue sym -> (SEval sym (GenValue sym) -> S
fromVFun sym val = case val of fromVFun sym val = case val of
VFun fnstk f -> VFun fnstk f ->
\x -> sModifyCallStack sym (\stk -> combineCallStacks stk fnstk) (f x) \x -> sModifyCallStack sym (\stk -> combineCallStacks stk fnstk) (f x)
_ -> evalPanic "fromVFun" ["not a function"] _ -> evalPanic "fromVFun" ["not a function", show val]
-- | Extract a polymorphic function from a value. -- | Extract a polymorphic function from a value.
fromVPoly :: Backend sym => sym -> GenValue sym -> (TValue -> SEval sym (GenValue sym)) fromVPoly :: Backend sym => sym -> GenValue sym -> (TValue -> SEval sym (GenValue sym))
fromVPoly sym val = case val of fromVPoly sym val = case val of
VPoly fnstk f -> VPoly fnstk f ->
\x -> sModifyCallStack sym (\stk -> combineCallStacks stk fnstk) (f x) \x -> sModifyCallStack sym (\stk -> combineCallStacks stk fnstk) (f x)
_ -> evalPanic "fromVPoly" ["not a polymorphic value"] _ -> evalPanic "fromVPoly" ["not a polymorphic value", show val]
-- | Extract a polymorphic function from a value. -- | Extract a polymorphic function from a value.
fromVNumPoly :: Backend sym => sym -> GenValue sym -> (Nat' -> SEval sym (GenValue sym)) fromVNumPoly :: Backend sym => sym -> GenValue sym -> (Nat' -> SEval sym (GenValue sym))
fromVNumPoly sym val = case val of fromVNumPoly sym val = case val of
VNumPoly fnstk f -> VNumPoly fnstk f ->
\x -> sModifyCallStack sym (\stk -> combineCallStacks stk fnstk) (f x) \x -> sModifyCallStack sym (\stk -> combineCallStacks stk fnstk) (f x)
_ -> evalPanic "fromVNumPoly" ["not a polymorphic value"] _ -> evalPanic "fromVNumPoly" ["not a polymorphic value", show val]
-- | Extract a tuple from a value. -- | Extract a tuple from a value.
fromVTuple :: GenValue sym -> [SEval sym (GenValue sym)] fromVTuple :: GenValue sym -> [SEval sym (GenValue sym)]
fromVTuple val = case val of fromVTuple val = case val of
VTuple vs -> vs VTuple vs -> vs
_ -> evalPanic "fromVTuple" ["not a tuple"] _ -> evalPanic "fromVTuple" ["not a tuple", show val]
-- | Extract a record from a value. -- | Extract a record from a value.
fromVRecord :: GenValue sym -> RecordMap Ident (SEval sym (GenValue sym)) fromVRecord :: GenValue sym -> RecordMap Ident (SEval sym (GenValue sym))
fromVRecord val = case val of fromVRecord val = case val of
VRecord fs -> fs VRecord fs -> fs
_ -> evalPanic "fromVRecord" ["not a record"] _ -> evalPanic "fromVRecord" ["not a record", show val]
fromVFloat :: GenValue sym -> SFloat sym fromVFloat :: GenValue sym -> SFloat sym
fromVFloat val = fromVFloat val =
case val of case val of
VFloat x -> x VFloat x -> x
_ -> evalPanic "fromVFloat" ["not a Float"] _ -> evalPanic "fromVFloat" ["not a Float", show val]
-- | Lookup a field in a record. -- | Lookup a field in a record.
lookupRecord :: Ident -> GenValue sym -> SEval sym (GenValue sym) lookupRecord :: Ident -> GenValue sym -> SEval sym (GenValue sym)
lookupRecord f val = lookupRecord f val =
case lookupField f (fromVRecord val) of case lookupField f (fromVRecord val) of
Just x -> x Just x -> x
Nothing -> evalPanic "lookupRecord" ["malformed record"] Nothing -> evalPanic "lookupRecord" ["malformed record", show val]
-- Merge and if/then/else -- Merge and if/then/else
@ -532,7 +533,7 @@ mergeValue sym c v1 v2 =
(VRecord fs1 , VRecord fs2 ) -> (VRecord fs1 , VRecord fs2 ) ->
do let res = zipRecords (\_lbl -> mergeValue' sym c) fs1 fs2 do let res = zipRecords (\_lbl -> mergeValue' sym c) fs1 fs2
case res of case res of
Left f -> panic "Cryptol.Eval.Generic" [ "mergeValue: incompatible record values", show f ] Left f -> panic "Cryptol.Eval.Value" [ "mergeValue: incompatible record values", show f ]
Right r -> pure (VRecord r) Right r -> pure (VRecord r)
(VTuple vs1 , VTuple vs2 ) | length vs1 == length vs2 -> (VTuple vs1 , VTuple vs2 ) | length vs1 == length vs2 ->
pure $ VTuple $ zipWith (mergeValue' sym c) vs1 vs2 pure $ VTuple $ zipWith (mergeValue' sym c) vs1 vs2
@ -545,8 +546,8 @@ mergeValue sym c v1 v2 =
(VStream vs1 , VStream vs2 ) -> VStream <$> memoMap sym (mergeSeqMapVal sym c vs1 vs2) (VStream vs1 , VStream vs2 ) -> VStream <$> memoMap sym (mergeSeqMapVal sym c vs1 vs2)
(f1@VFun{} , f2@VFun{} ) -> lam sym $ \x -> mergeValue' sym c (fromVFun sym f1 x) (fromVFun sym f2 x) (f1@VFun{} , f2@VFun{} ) -> lam sym $ \x -> mergeValue' sym c (fromVFun sym f1 x) (fromVFun sym f2 x)
(f1@VPoly{} , f2@VPoly{} ) -> tlam sym $ \x -> mergeValue' sym c (fromVPoly sym f1 x) (fromVPoly sym f2 x) (f1@VPoly{} , f2@VPoly{} ) -> tlam sym $ \x -> mergeValue' sym c (fromVPoly sym f1 x) (fromVPoly sym f2 x)
(_ , _ ) -> panic "Cryptol.Eval.Generic" (_ , _ ) -> panic "Cryptol.Eval.Value"
[ "mergeValue: incompatible values" ] [ "mergeValue: incompatible values", show v1, show v2 ]
{-# INLINE mergeSeqMapVal #-} {-# INLINE mergeSeqMapVal #-}
mergeSeqMapVal :: Backend sym => mergeSeqMapVal :: Backend sym =>