diff --git a/cryptol.cabal b/cryptol.cabal index d2a6be02..abd18e48 100644 --- a/cryptol.cabal +++ b/cryptol.cabal @@ -122,6 +122,7 @@ library Cryptol.TypeCheck.Solver.Class, Cryptol.TypeCheck.Solver.Selector, Cryptol.TypeCheck.Solver.Utils, + Cryptol.TypeCheck.Solver.Simplify, Cryptol.TypeCheck.Solver.CrySAT, Cryptol.TypeCheck.Solver.Numeric.AST, diff --git a/src/Cryptol/TypeCheck/Solve.hs b/src/Cryptol/TypeCheck/Solve.hs index 4ebbe96d..9a5af2ad 100644 --- a/src/Cryptol/TypeCheck/Solve.hs +++ b/src/Cryptol/TypeCheck/Solve.hs @@ -21,6 +21,7 @@ module Cryptol.TypeCheck.Solve import Cryptol.Parser.AST(LQName, thing) import Cryptol.Parser.Position (emptyRange) +import Cryptol.TypeCheck.PP(pp) import Cryptol.TypeCheck.AST import Cryptol.TypeCheck.Monad import Cryptol.TypeCheck.Subst @@ -35,6 +36,7 @@ import qualified Cryptol.TypeCheck.Solver.Numeric.Simplify1 as Num import qualified Cryptol.TypeCheck.Solver.Numeric.SimplifyExpr as Num import qualified Cryptol.TypeCheck.Solver.CrySAT as Num import Cryptol.TypeCheck.Solver.CrySAT (debugBlock, DebugLog(..)) +import Cryptol.TypeCheck.Solver.Simplify (tryRewritePropAsSubst) import Cryptol.Utils.PP (text) import Cryptol.Utils.Panic(panic) import Cryptol.Utils.Misc(anyJust) @@ -180,6 +182,8 @@ numericRight g = case Num.exportProp (goal g) of Nothing -> Left g + + {- Constraints and satisfiability: 1. [Satisfiable] A collection of constraints is _satisfiable_, if there is an @@ -237,7 +241,7 @@ simpGoals' s gs0 = go emptySubst [] (wellFormed gs0 ++ gs0) Left err -> return (Left err, su) Right impSu -> let (unchanged,changed) = - partitionEithers (map (applyImp su) gs3) + partitionEithers (map (applyImp impSu) gs3) new = wellFormed changed in go (impSu @@ su) unchanged (new ++ changed) @@ -311,7 +315,12 @@ solveNumerics s consultGs solveGs = computeImprovements :: Num.Solver -> [Goal] -> IO (Either [Goal] Subst) computeImprovements s gs - | (x,t) : _ <- mapMaybe improveByDefn gs = return (Right (singleSubst x t)) + -- Find things of the form: `x = t`. We might do some rewriting to put + -- it in this form, if needed. + | (x,t) : _ <- mapMaybe (tryRewritePropAsSubst . goal) gs = + do let su = singleSubst x t + debugLog s ("Improve by definition: " ++ show (pp su)) + return (Right su) | otherwise = debugBlock s "Computing improvements" $ do let nums = [ g | Right g <- map numericRight gs ] @@ -333,23 +342,6 @@ computeImprovements s gs -{- | If we see an equation: `?x = e`, and: - * ?x is a unification variable - * `e` is "zonked" (substitution is fully applied) - * ?x does not appear in `e`. - then, we can improve `?x` to `e`. --} -improveByDefn :: Goal -> Maybe (TVar, Type) -improveByDefn g = - do res <- pIsEq (goal g) - case res of - (TVar x, t) -> tryToBind x t - (t, TVar x) -> tryToBind x t - _ -> Nothing - where - tryToBind x t = - do guard (isFreeTV x && not (x `Set.member` fvs t)) - return (x,t) -- | Import an improving substitutin (i.e., a bunch of equations) diff --git a/src/Cryptol/TypeCheck/Solver/Simplify.hs b/src/Cryptol/TypeCheck/Solver/Simplify.hs new file mode 100644 index 00000000..699f5700 --- /dev/null +++ b/src/Cryptol/TypeCheck/Solver/Simplify.hs @@ -0,0 +1,98 @@ +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE MultiWayIf #-} + +module Cryptol.TypeCheck.Solver.Simplify ( + tryRewritePropAsSubst + ) where + + +import Cryptol.Prims.Syntax (TFun(..)) +import Cryptol.TypeCheck.AST (Type(..),Prop,TVar,pIsEq,isFreeTV,TCon(..)) +import Cryptol.TypeCheck.Subst (fvs) + +import Control.Monad (msum,guard,mzero) +import qualified Data.Set as Set + + +-- | When given an equality constraint, attempt to rewrite it to the form `?x = +-- ...`, by moving all occurrences of `?x` to the LHS, and any other variables +-- to the RHS. This will only work when there's only one unification variable +-- present in the prop. +tryRewritePropAsSubst :: Prop -> Maybe (TVar,Type) +tryRewritePropAsSubst p = + do (x,y) <- pIsEq p + + -- extract the single unification variable from the prop (there can be + -- bound variables remaining) + let xfvs = fvs x + yfvs = fvs y + vars = Set.toList (Set.filter isFreeTV (Set.union xfvs yfvs)) + [uvar] <- return vars + + rhs <- msum [ simpleCase uvar x y yfvs + , simpleCase uvar y x xfvs + , oneSided uvar x y yfvs + , oneSided uvar y x xfvs + ] + + return (uvar,rhs) + + where + + -- Check for the case where l is a free variable, and the rhs doesn't mention + -- it. + simpleCase uvar l r rfvs = + do guard (TVar uvar == l && uvar `Set.notMember` rfvs) + return r + + -- Check for the case where the unification variable only occurs on one side + -- of the constraint. + oneSided uvar l r rfvs = + do guard (uvar `Set.notMember` rfvs) + rewriteLHS uvar l r + +-- | Rewrite an equality until the LHS is just uvar. Return the rewritten RHS. +rewriteLHS :: TVar -> Type -> Type -> Maybe Type +rewriteLHS uvar = go + where + + go (TVar tv) rhs | tv == uvar = return rhs + + go (TCon (TF tf) [x,y]) rhs = + do let xfvs = fvs x + yfvs = fvs y + + inX = Set.member uvar xfvs + inY = Set.member uvar yfvs + + -- for now, don't handle the complicated case where the variable shows up + -- multiple times in an expression + if | inX && inY -> mzero + | inX -> applyR x tf y rhs + | inY -> applyL x tf y rhs + + + -- discard type synonyms, the rewriting will make them no longer apply + go (TUser _ _ l) rhs = + go l rhs + + -- records won't work here. + go _ _ = + mzero + + + -- invert the type function to balance the equation, when the variable occurs + -- on the LHS of the expression `x tf y` + applyR x TCAdd y rhs = go x (TCon (TF TCSub) [rhs,y]) + applyR x TCSub y rhs = go x (TCon (TF TCAdd) [rhs,y]) + applyR x TCMul y rhs = go x (TCon (TF TCDiv) [rhs,y]) + applyR x TCDiv y rhs = go x (TCon (TF TCMul) [rhs,y]) + applyR _ _ _ _ = mzero + + -- invert the type function to balance the equation, when the variable occurs + -- on the RHS of the expression `x tf y` + applyL x TCAdd y rhs = go y (TCon (TF TCSub) [rhs,x]) + applyL x TCMul y rhs = go y (TCon (TF TCDiv) [rhs,x]) + applyL x TCSub y rhs = go (TCon (TF TCAdd) [rhs,y]) x + applyL x TCDiv y rhs = go (TCon (TF TCMul) [rhs,y]) x + applyL _ _ _ _ = mzero diff --git a/src/Cryptol/TypeCheck/Unify.hs b/src/Cryptol/TypeCheck/Unify.hs index 3b6dcabf..7e8e43bd 100644 --- a/src/Cryptol/TypeCheck/Unify.hs +++ b/src/Cryptol/TypeCheck/Unify.hs @@ -6,6 +6,7 @@ -- Stability : provisional -- Portability : portable +{-# LANGUAGE CPP #-} {-# LANGUAGE Safe #-} {-# LANGUAGE PatternGuards, ViewPatterns #-} {-# LANGUAGE DeriveFunctor #-}