Only add subtraction in special cases during goal rewriting

It's only OK to subtract something if you know that it's finite.  This change
adds a fairly conservative check when rewriting `a + b = r` to `a = r - b` that
`b` be finite.
This commit is contained in:
Trevor Elliott 2015-08-13 11:15:01 -07:00
parent dadc5e1781
commit 9fa56160f8
2 changed files with 90 additions and 42 deletions

View File

@ -36,7 +36,8 @@ import qualified Cryptol.TypeCheck.Solver.Numeric.Simplify1 as Num
import qualified Cryptol.TypeCheck.Solver.Numeric.SimplifyExpr as Num
import qualified Cryptol.TypeCheck.Solver.CrySAT as Num
import Cryptol.TypeCheck.Solver.CrySAT (debugBlock, DebugLog(..))
import Cryptol.TypeCheck.Solver.Simplify (tryRewritePropAsSubst)
import Cryptol.TypeCheck.Solver.Simplify
(Fins,filterFins,tryRewritePropAsSubst)
import Cryptol.Utils.PP (text)
import Cryptol.Utils.Panic(panic)
import Cryptol.Utils.Misc(anyJust)
@ -317,7 +318,7 @@ computeImprovements :: Num.Solver -> [Goal] -> IO (Either [Goal] Subst)
computeImprovements s gs
-- Find things of the form: `x = t`. We might do some rewriting to put
-- it in this form, if needed.
| (x,t) : _ <- mapMaybe (tryRewritePropAsSubst . goal) gs =
| (x,t) : _ <- mapMaybe (improveByDefn (filterFins gs)) gs =
do let su = singleSubst x t
debugLog s ("Improve by definition: " ++ show (pp su))
return (Right su)
@ -340,7 +341,10 @@ computeImprovements s gs
return (Left bad)
improveByDefn :: Fins -> Goal -> Maybe (TVar,Type)
improveByDefn fins Goal { .. } =
do (var,ty) <- tryRewritePropAsSubst fins goal
return (var,simpType ty)

View File

@ -2,58 +2,95 @@
{-# LANGUAGE MultiWayIf #-}
module Cryptol.TypeCheck.Solver.Simplify (
Fins, filterFins,
tryRewritePropAsSubst
) where
import Cryptol.Prims.Syntax (TFun(..))
import Cryptol.TypeCheck.AST (Type(..),Prop,TVar,pIsEq,isFreeTV,TCon(..))
import Cryptol.TypeCheck.AST (Type(..),Prop,TVar,pIsEq,isFreeTV,TCon(..),pIsFin)
import Cryptol.TypeCheck.InferTypes (Goal(..))
import Cryptol.TypeCheck.Subst (fvs)
import Control.Monad (msum,guard,mzero)
import Data.Function (on)
import Data.List (sortBy)
import Data.Maybe (catMaybes,listToMaybe)
import qualified Data.Set as Set
-- | Type variables that are known to have a `fin` constraint. This set is used
-- to justify the addition of a subtraction on the rhs of an equality
-- constraint.
type Fins = Set.Set TVar
filterFins :: [Goal] -> Fins
filterFins gs = Set.unions [ fvs ty | Goal { .. } <- gs
, Just ty <- [pIsFin goal] ]
-- | When given an equality constraint, attempt to rewrite it to the form `?x =
-- ...`, by moving all occurrences of `?x` to the LHS, and any other variables
-- to the RHS. This will only work when there's only one unification variable
-- present in the prop.
tryRewritePropAsSubst :: Prop -> Maybe (TVar,Type)
tryRewritePropAsSubst p =
tryRewritePropAsSubst :: Fins -> Prop -> Maybe (TVar,Type)
tryRewritePropAsSubst fins p =
do (x,y) <- pIsEq p
-- extract the single unification variable from the prop (there can be
-- bound variables remaining)
let xfvs = fvs x
yfvs = fvs y
vars = Set.toList (Set.filter isFreeTV (Set.union xfvs yfvs))
[uvar] <- return vars
let vars = Set.toList (Set.filter isFreeTV (fvs p))
rhs <- msum [ simpleCase uvar x y yfvs
, simpleCase uvar y x xfvs
, oneSided uvar x y yfvs
, oneSided uvar y x xfvs
]
listToMaybe $ sortBy (flip compare `on` rank)
$ catMaybes [ tryRewriteEq fins var x y | var <- vars ]
return (uvar,rhs)
-- | Rank a rewrite.
rank :: (TVar,Type) -> Int
rank (_,ty) = go ty
where
go (TCon (TF TCAdd) ts) = sum (map go ts) + 1
go (TCon (TF TCSub) ts) = sum (map go ts) - 1
go (TCon (TF TCMul) ts) = sum (map go ts) + 1
go (TCon (TF TCDiv) ts) = sum (map go ts) - 1
go (TCon _ ts) = sum (map go ts)
go _ = 0
-- | Rewrite an equation with respect to a unification variable ?x, into the
-- form `?x = t`.
tryRewriteEq :: Fins -> TVar -> Type -> Type -> Maybe (TVar,Type)
tryRewriteEq fins uvar l r =
msum [ do guard (uvarTy == l && uvar `Set.notMember` rfvs)
return (uvar, r)
, do guard (uvarTy == r && uvar `Set.notMember` lfvs)
return (uvar, l)
, do guard (uvar `Set.notMember` rfvs)
ty <- rewriteLHS fins uvar l r
return (uvar,ty)
, do guard (uvar `Set.notMember` lfvs)
ty <- rewriteLHS fins uvar r l
return (uvar,ty)
]
where
-- Check for the case where l is a free variable, and the rhs doesn't mention
-- it.
simpleCase uvar l r rfvs =
do guard (TVar uvar == l && uvar `Set.notMember` rfvs)
return r
uvarTy = TVar uvar
lfvs = fvs l
rfvs = fvs r
-- | Check that a type contains only finite type variables.
allFin :: Fins -> Type -> Bool
allFin fins ty = fvs ty `Set.isSubsetOf` fins
-- Check for the case where the unification variable only occurs on one side
-- of the constraint.
oneSided uvar l r rfvs =
do guard (uvar `Set.notMember` rfvs)
rewriteLHS uvar l r
-- | Rewrite an equality until the LHS is just uvar. Return the rewritten RHS.
rewriteLHS :: TVar -> Type -> Type -> Maybe Type
rewriteLHS uvar = go
rewriteLHS :: Fins -> TVar -> Type -> Type -> Maybe Type
rewriteLHS fins uvar = go
where
go (TVar tv) rhs | tv == uvar = return rhs
@ -68,8 +105,9 @@ rewriteLHS uvar = go
-- for now, don't handle the complicated case where the variable shows up
-- multiple times in an expression
if | inX && inY -> mzero
| inX -> applyR x tf y rhs
| inY -> applyL x tf y rhs
| inX -> balanceR x tf y rhs
| inY -> balanceL x tf y rhs
| otherwise -> mzero
-- discard type synonyms, the rewriting will make them no longer apply
@ -83,16 +121,22 @@ rewriteLHS uvar = go
-- invert the type function to balance the equation, when the variable occurs
-- on the LHS of the expression `x tf y`
applyR x TCAdd y rhs = go x (TCon (TF TCSub) [rhs,y])
applyR x TCSub y rhs = go x (TCon (TF TCAdd) [rhs,y])
applyR x TCMul y rhs = go x (TCon (TF TCDiv) [rhs,y])
applyR x TCDiv y rhs = go x (TCon (TF TCMul) [rhs,y])
applyR _ _ _ _ = mzero
balanceR x TCAdd y rhs = do guardSubtract y
go x (TCon (TF TCSub) [rhs,y])
balanceR x TCSub y rhs = go x (TCon (TF TCAdd) [rhs,y])
balanceR _ _ _ _ = mzero
-- invert the type function to balance the equation, when the variable occurs
-- on the RHS of the expression `x tf y`
applyL x TCAdd y rhs = go y (TCon (TF TCSub) [rhs,x])
applyL x TCMul y rhs = go y (TCon (TF TCDiv) [rhs,x])
applyL x TCSub y rhs = go (TCon (TF TCAdd) [rhs,y]) x
applyL x TCDiv y rhs = go (TCon (TF TCMul) [rhs,y]) x
applyL _ _ _ _ = mzero
balanceL x TCAdd y rhs = do guardSubtract y
go y (TCon (TF TCSub) [rhs,x])
balanceL x TCSub y rhs = go (TCon (TF TCAdd) [rhs,y]) x
balanceL _ _ _ _ = mzero
-- guard that it's OK to subtract this type from something else.
--
-- XXX this ignores things like `min x inf` where x is finite, and just
-- assumes that it won't work.
guardSubtract ty = guard (allFin fins ty)