Reimplement comparison primitives by recursion on the TValue.

This commit is contained in:
Brian Huffman 2018-06-11 16:40:57 -07:00
parent f7b8373a4a
commit e11806f64a
3 changed files with 40 additions and 33 deletions

View File

@ -669,6 +669,12 @@ fromVInteger val = case val of
VInteger i -> i
_ -> evalPanic "fromVInteger" ["not an Integer"]
-- | Extract a finite sequence value.
fromVSeq :: GenValue b w i -> SeqMap b w i
fromVSeq val = case val of
VSeq _ vs -> vs
_ -> evalPanic "fromVSeq" ["not a sequence"]
-- | Extract a sequence.
fromSeq :: forall b w i. BitWord b w i => String -> GenValue b w i -> Eval (SeqMap b w i)
fromSeq msg val = case val of

View File

@ -465,39 +465,40 @@ cmpValue :: BitWord b w i
=> (b -> b -> Eval a -> Eval a)
-> (w -> w -> Eval a -> Eval a)
-> (i -> i -> Eval a -> Eval a)
-> (GenValue b w i -> GenValue b w i -> Eval a -> Eval a)
-> (TValue -> GenValue b w i -> GenValue b w i -> Eval a -> Eval a)
cmpValue fb fw fi = cmp
where
cmp v1 v2 k =
case (v1, v2) of
(VRecord fs1, VRecord fs2) -> let vals = map snd . sortBy (comparing fst)
in cmpValues (vals fs1) (vals fs2) k
(VTuple vs1 , VTuple vs2 ) -> cmpValues vs1 vs2 k
(VBit b1 , VBit b2 ) -> fb b1 b2 k
(VInteger i1, VInteger i2) -> fi i1 i2 k
(VWord _ w1 , VWord _ w2 ) -> join (fw <$> (asWordVal =<< w1)
<*> (asWordVal =<< w2)
<*> return k)
(VSeq n vs1 , VSeq _ vs2 ) -> cmpValues (enumerateSeqMap n vs1)
(enumerateSeqMap n vs2) k
(VStream {} , VStream {} ) -> panic "Cryptol.Prims.Value.cmpValue"
[ "Infinite streams are not comparable" ]
(VFun {} , VFun {} ) -> panic "Cryptol.Prims.Value.cmpValue"
[ "Functions are not comparable" ]
(VPoly {} , VPoly {} ) -> panic "Cryptol.Prims.Value.cmpValue"
[ "Polymorphic values are not comparable" ]
(_ , _ ) -> panic "Cryptol.Prims.Value.cmpValue"
[ "type mismatch" ]
cmp ty v1 v2 k =
case ty of
TVBit -> fb (fromVBit v1) (fromVBit v2) k
TVInteger -> fi (fromVInteger v1) (fromVInteger v2) k
TVSeq n t
| isTBit t -> do w1 <- fromVWord "cmpValue" v1
w2 <- fromVWord "cmpValue" v2
fw w1 w2 k
| otherwise -> cmpValues (repeat t)
(enumerateSeqMap n (fromVSeq v1))
(enumerateSeqMap n (fromVSeq v2)) k
TVStream _ -> panic "Cryptol.Prims.Value.cmpValue"
[ "Infinite streams are not comparable" ]
TVFun _ _ -> panic "Cryptol.Prims.Value.cmpValue"
[ "Functions are not comparable" ]
TVTuple tys -> cmpValues tys (fromVTuple v1) (fromVTuple v2) k
TVRec fields -> do let vals = map snd . sortBy (comparing fst)
let tys = vals fields
cmpValues tys
(vals (fromVRecord v1))
(vals (fromVRecord v2)) k
cmpValues (x1 : xs1) (x2 : xs2) k = do
x1' <- x1
x2' <- x2
cmp x1' x2' (cmpValues xs1 xs2 k)
cmpValues _ _ k = k
cmpValues (t : ts) (x1 : xs1) (x2 : xs2) k =
do x1' <- x1
x2' <- x2
cmp t x1' x2' (cmpValues ts xs1 xs2 k)
cmpValues _ _ _ k = k
lexCompare :: Value -> Value -> Eval Ordering
lexCompare a b = cmpValue op opw op a b (return EQ)
lexCompare :: TValue -> Value -> Value -> Eval Ordering
lexCompare ty a b = cmpValue op opw op ty a b (return EQ)
where
opw :: BV -> BV -> Eval Ordering -> Eval Ordering
opw x y k = op (bvVal x) (bvVal y) k
@ -507,8 +508,8 @@ lexCompare a b = cmpValue op opw op a b (return EQ)
EQ -> k
cmp -> return cmp
signedLexCompare :: Value -> Value -> Eval Ordering
signedLexCompare a b = cmpValue opb opw opi a b (return EQ)
signedLexCompare :: TValue -> Value -> Value -> Eval Ordering
signedLexCompare ty a b = cmpValue opb opw opi ty a b (return EQ)
where
opb :: Bool -> Bool -> Eval Ordering -> Eval Ordering
opb _x _y _k = panic "signedLexCompare"
@ -525,11 +526,11 @@ signedLexCompare a b = cmpValue opb opw opi a b (return EQ)
-- | Process two elements based on their lexicographic ordering.
cmpOrder :: String -> (Ordering -> Bool) -> Binary Bool BV Integer
cmpOrder _nm op _ty l r = VBit . op <$> lexCompare l r
cmpOrder _nm op ty l r = VBit . op <$> lexCompare ty l r
-- | Process two elements based on their lexicographic ordering, using signed comparisons
signedCmpOrder :: String -> (Ordering -> Bool) -> Binary Bool BV Integer
signedCmpOrder _nm op _ty l r = VBit . op <$> signedLexCompare l r
signedCmpOrder _nm op ty l r = VBit . op <$> signedLexCompare ty l r
-- Signed arithmetic -----------------------------------------------------------

View File

@ -495,7 +495,7 @@ cmpBinary :: (SBool -> SBool -> Eval SBool -> Eval SBool)
-> (SWord -> SWord -> Eval SBool -> Eval SBool)
-> (SInteger -> SInteger -> Eval SBool -> Eval SBool)
-> SBool -> Binary SBool SWord SInteger
cmpBinary fb fw fi b _ty v1 v2 = VBit <$> cmpValue fb fw fi v1 v2 (return b)
cmpBinary fb fw fi b ty v1 v2 = VBit <$> cmpValue fb fw fi ty v1 v2 (return b)
-- Signed arithmetic -----------------------------------------------------------