More rules and things; external solver disabled; we can at least load ChaCha

This commit is contained in:
Iavor S. Diatchki 2017-02-08 15:08:50 -08:00
parent 90c6adbfcd
commit 710355176a
17 changed files with 1064 additions and 503 deletions

View File

@ -94,6 +94,7 @@ library
Cryptol.Utils.Panic, Cryptol.Utils.Panic,
Cryptol.Utils.Debug, Cryptol.Utils.Debug,
Cryptol.Utils.Misc, Cryptol.Utils.Misc,
Cryptol.Utils.Patterns,
Cryptol.Version, Cryptol.Version,
Cryptol.ModuleSystem, Cryptol.ModuleSystem,
@ -107,6 +108,8 @@ library
Cryptol.TypeCheck, Cryptol.TypeCheck,
Cryptol.TypeCheck.Type, Cryptol.TypeCheck.Type,
Cryptol.TypeCheck.TypePat,
Cryptol.TypeCheck.SimpType,
Cryptol.TypeCheck.AST, Cryptol.TypeCheck.AST,
Cryptol.TypeCheck.Monad, Cryptol.TypeCheck.Monad,
Cryptol.TypeCheck.Infer, Cryptol.TypeCheck.Infer,
@ -130,6 +133,7 @@ library
Cryptol.TypeCheck.Solver.Utils, Cryptol.TypeCheck.Solver.Utils,
Cryptol.TypeCheck.Solver.Numeric, Cryptol.TypeCheck.Solver.Numeric,
Cryptol.TypeCheck.Solver.Improve,
Cryptol.TypeCheck.Solver.CrySAT, Cryptol.TypeCheck.Solver.CrySAT,
Cryptol.TypeCheck.Solver.Numeric.AST, Cryptol.TypeCheck.Solver.Numeric.AST,
Cryptol.TypeCheck.Solver.Numeric.ImportExport, Cryptol.TypeCheck.Solver.Numeric.ImportExport,

View File

@ -22,7 +22,7 @@ import qualified Cryptol.Parser.Names as P
import Cryptol.TypeCheck.AST hiding (tSub,tMul,tExp) import Cryptol.TypeCheck.AST hiding (tSub,tMul,tExp)
import Cryptol.TypeCheck.Monad import Cryptol.TypeCheck.Monad
import Cryptol.TypeCheck.Solve import Cryptol.TypeCheck.Solve
import Cryptol.TypeCheck.SimpleSolver(tSub,tMul,tExp) import Cryptol.TypeCheck.SimpType(tSub,tMul,tExp)
import Cryptol.TypeCheck.Kind(checkType,checkSchema,checkTySyn, import Cryptol.TypeCheck.Kind(checkType,checkSchema,checkTySyn,
checkNewtype) checkNewtype)
import Cryptol.TypeCheck.Instantiate import Cryptol.TypeCheck.Instantiate

View File

@ -20,6 +20,7 @@ import Cryptol.Parser.AST (Named(..))
import Cryptol.Parser.Position import Cryptol.Parser.Position
import Cryptol.TypeCheck.AST import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Monad hiding (withTParams) import Cryptol.TypeCheck.Monad hiding (withTParams)
import Cryptol.TypeCheck.SimpType(tRebuild)
import Cryptol.TypeCheck.Solve (simplifyAllConstraints import Cryptol.TypeCheck.Solve (simplifyAllConstraints
,wfTypeFunction) ,wfTypeFunction)
import Cryptol.Utils.PP import Cryptol.Utils.PP
@ -42,7 +43,9 @@ checkSchema (P.Forall xs ps t mb) =
do ps1 <- mapM checkProp ps do ps1 <- mapM checkProp ps
t1 <- doCheckType t (Just KType) t1 <- doCheckType t (Just KType)
return (ps1,t1) return (ps1,t1)
return (Forall xs1 ps1 t1, gs) return ( Forall xs1 (map tRebuild ps1) (tRebuild t1)
, [ g { goal = tRebuild (goal g) } | g <- gs ]
)
where where
rng = case mb of rng = case mb of
@ -59,8 +62,8 @@ checkTySyn (P.TySyn x as t) =
return r return r
return TySyn { tsName = thing x return TySyn { tsName = thing x
, tsParams = as1 , tsParams = as1
, tsConstraints = map goal gs , tsConstraints = map (tRebuild . goal) gs
, tsDef = t1 , tsDef = tRebuild t1
} }
-- | Check a newtype declaration. -- | Check a newtype declaration.
@ -89,7 +92,7 @@ checkNewtype (P.Newtype x as fs) =
checkType :: P.Type Name -> Maybe Kind -> InferM Type checkType :: P.Type Name -> Maybe Kind -> InferM Type
checkType t k = checkType t k =
do (_, t1) <- withTParams True [] $ doCheckType t k do (_, t1) <- withTParams True [] $ doCheckType t k
return t1 return (tRebuild t1)
{- | Check something with type parameters. {- | Check something with type parameters.

View File

@ -325,10 +325,7 @@ addGoals :: [Goal] -> InferM ()
addGoals gs0 = doAdd =<< simpGoals gs0 addGoals gs0 = doAdd =<< simpGoals gs0
where where
doAdd [] = return () doAdd [] = return ()
doAdd gs = doAdd gs = IM $ sets_ $ \s -> s { iCts = foldl' (flip insertGoal) (iCts s) gs }
do io $ putStrLn "Adding goals:"
io $ mapM_ putStrLn [ " " ++ show (pp (goal g)) | g <- gs ]
IM $ sets_ $ \s -> s { iCts = foldl' (flip insertGoal) (iCts s) gs }
-- | Collect the goals emitted by the given sub-computation. -- | Collect the goals emitted by the given sub-computation.

View File

@ -0,0 +1,242 @@
{-# LANGUAGE PatternGuards #-}
module Cryptol.TypeCheck.SimpType where
import Control.Applicative((<|>))
import Cryptol.TypeCheck.Type hiding
(tAdd,tSub,tMul,tDiv,tMod,tExp,tMin,tMax,tWidth,tLenFromThen,tLenFromThenTo)
import Cryptol.TypeCheck.TypePat
import Cryptol.TypeCheck.Solver.InfNat
import Control.Monad(msum,guard)
import Cryptol.TypeCheck.PP(pp)
tRebuild :: Type -> Type
tRebuild = go
where
go ty =
case ty of
TUser x xs t -> TUser x xs (go t)
TVar _ -> ty
TRec xs -> TRec [ (x,go y) | (x,y) <- xs ]
TCon tc ts ->
case (tc, map go ts) of
(TF f, ts') ->
case (f,ts') of
(TCAdd,[x,y]) -> tAdd x y
(TCSub,[x,y]) -> tSub x y
(TCMul,[x,y]) -> tMul x y
(TCExp,[x,y]) -> tExp x y
(TCDiv,[x,y]) -> tDiv x y
(TCMod,[x,y]) -> tMod x y
(TCMin,[x,y]) -> tMin x y
(TCMax,[x,y]) -> tMax x y
(TCWidth,[x]) -> tWidth x
(TCLenFromThen,[x,y,z]) -> tLenFromThen x y z
(TCLenFromThenTo,[x,y,z]) -> tLenFromThenTo x y z
_ -> TCon tc ts
(_,ts') -> TCon tc ts'
-- Normal: constants to the left
tAdd :: Type -> Type -> Type
tAdd x y
| Just t <- tOp TCAdd (total (op2 nAdd)) [x,y] = t
| tIsInf x = tInf
| tIsInf y = tInf
| Just n <- tIsNum x = addK n y
| Just n <- tIsNum y = addK n x
| Just (n,x1) <- isSumK x = addK n (tAdd x1 y)
| Just (n,y1) <- isSumK y = addK n (tAdd x y1)
| Just v <- matchMaybe (do (a,b) <- (|-|) y
guard (x == b)
return a) = v
| Just v <- matchMaybe (do (a,b) <- (|-|) x
guard (b == y)
return a) = v
| otherwise = tf2 TCAdd x y
where
isSumK t = case tNoUser t of
TCon (TF TCAdd) [ l, r ] ->
do n <- tIsNum l
return (n, r)
_ -> Nothing
addK 0 t = t
addK n t | Just (m,b) <- isSumK t = tf2 TCAdd (tNum (n + m)) b
| otherwise = tf2 TCAdd (tNum n) t
tSub :: Type -> Type -> Type
tSub x y
| Just t <- tOp TCSub (op2 nSub) [x,y] = t
| tIsInf y = tBadNumber $ TCErrorMessage "Subtraction of `inf`."
| Just 0 <- yNum = x
| Just k <- yNum
, TCon (TF TCAdd) [a,b] <- tNoUser x
, Just n <- tIsNum a = case compare k n of
EQ -> b
LT -> tf2 TCAdd (tNum (n - k)) b
GT -> tSub b (tNum (k - n))
| Just v <- matchMaybe (do (a,b) <- anAdd x
(guard (a == y) >> return b)
<|> (guard (b == y) >> return a))
= v
| Just v <- matchMaybe (do (a,b) <- (|-|) y
return (tSub (tAdd x b) a)) = v
| otherwise = tf2 TCSub x y
where
yNum = tIsNum y
-- Normal: constants to the left
tMul :: Type -> Type -> Type
tMul x y
| Just t <- tOp TCMul (total (op2 nMul)) [x,y] = t
| Just n <- tIsNum x = mulK n y
| Just n <- tIsNum y = mulK n x
| otherwise = tf2 TCMul x y
where
mulK 0 _ = tNum (0 :: Int)
mulK 1 t = t
mulK n t | TCon (TF TCMul) [a,b] <- t'
, Just a' <- tIsNat' a = case a' of
Inf -> t
Nat m -> tf2 TCMul (tNum (n * m)) b
| TCon (TF TCDiv) [a,b] <- t'
, Just b' <- tIsNum b
-- XXX: similar for a = b * k?
, n == b' = tSub a (tMod a b)
| otherwise = tf2 TCMul (tNum n) t
where t' = tNoUser t
tDiv :: Type -> Type -> Type
tDiv x y
| Just t <- tOp TCDiv (op2 nDiv) [x,y] = t
| tIsInf x = tBadNumber $ TCErrorMessage "Division of `inf`."
| Just 0 <- tIsNum y = tBadNumber $ TCErrorMessage "Division by 0."
| otherwise = tf2 TCDiv x y
tMod :: Type -> Type -> Type
tMod x y
| Just t <- tOp TCMod (op2 nMod) [x,y] = t
| tIsInf x = tBadNumber $ TCErrorMessage "Modulus of `inf`."
| Just 0 <- tIsNum x = tBadNumber $ TCErrorMessage "Modulus by 0."
| otherwise = tf2 TCMod x y
tExp :: Type -> Type -> Type
tExp x y
| Just t <- tOp TCExp (total (op2 nExp)) [x,y] = t
| Just 0 <- tIsNum y = tNum (1 :: Int)
| TCon (TF TCExp) [a,b] <- tNoUser y = tExp x (tMul a b)
| otherwise = tf2 TCExp x y
-- Normal: constants to the left
tMin :: Type -> Type -> Type
tMin x y
| Just t <- tOp TCMin (total (op2 nMin)) [x,y] = t
| Just n <- tIsNat' x = minK n y
| Just n <- tIsNat' y = minK n x
| x == y = x
-- XXX: min (k + t) t -> t
| otherwise = tf2 TCMin x y
where
minK Inf t = t
minK (Nat 0) _ = tNum (0 :: Int)
minK (Nat k) t
| TCon (TF TCAdd) [a,b] <- t'
, Just n <- tIsNum a = if k <= n then tNum k
else tAdd a (tMin (tNum (k - n)) b)
| TCon (TF TCSub) [a,b] <- t'
, Just n <- tIsNum a =
if k >= n then t else tSub a (tMax (tNum (n - k)) b)
| TCon (TF TCMin) [a,b] <- t'
, Just n <- tIsNum a = tf2 TCMin (tNum (min k n)) b
| otherwise = tf2 TCMin (tNum k) t
where t' = tNoUser t
-- Normal: constants to the left
tMax :: Type -> Type -> Type
tMax x y
| Just t <- tOp TCMax (total (op2 nMax)) [x,y] = t
| Just n <- tIsNat' x = maxK n y
| Just n <- tIsNat' y = maxK n x
| otherwise = tf2 TCMax x y
where
maxK Inf _ = tInf
maxK (Nat 0) t = t
maxK (Nat k) t
| TCon (TF TCAdd) [a,b] <- t'
, Just n <- tIsNum a = if k <= n
then t
else tMax (tNum (k - n)) b
| TCon (TF TCSub) [a,b] <- t'
, Just n <- tIsNat' a =
case n of
Inf -> t
Nat m -> if k >= m then tNum k else tSub a (tMin (tNum (m - k)) b)
| TCon (TF TCMax) [a,b] <- t'
, Just n <- tIsNum a = tf2 TCMax (tNum (max k n)) b
| otherwise = tf2 TCMax (tNum k) t
where t' = tNoUser t
tWidth :: Type -> Type
tWidth x
| Just t <- tOp TCWidth (total (op1 nWidth)) [x] = t
| otherwise = tf1 TCWidth x
tLenFromThen :: Type -> Type -> Type -> Type
tLenFromThen x y z
| Just t <- tOp TCLenFromThen (op3 nLenFromThen) [x,y,z] = t
-- XXX: rules?
| otherwise = tf3 TCLenFromThen x y z
tLenFromThenTo :: Type -> Type -> Type -> Type
tLenFromThenTo x y z
| Just t <- tOp TCLenFromThen (op3 nLenFromThen) [x,y,z] = t
| otherwise = tf3 TCLenFromThenTo x y z
total :: ([Nat'] -> Nat') -> ([Nat'] -> Maybe Nat')
total f xs = Just (f xs)
op1 :: (a -> b) -> [a] -> b
op1 f ~[x] = f x
op2 :: (a -> a -> b) -> [a] -> b
op2 f ~[x,y] = f x y
op3 :: (a -> a -> a -> b) -> [a] -> b
op3 f ~[x,y,z] = f x y z
-- | Common checks: check for error, or simple full evaluation.
tOp :: TFun -> ([Nat'] -> Maybe Nat') -> [Type] -> Maybe Type
tOp tf f ts
| Just e <- msum (map tIsError ts) = Just (tBadNumber e)
| Just xs <- mapM tIsNat' ts =
Just $ case f xs of
Nothing -> tBadNumber (err xs)
Just n -> tNat' n
| otherwise = Nothing
where
err xs = TCErrorMessage $
"Invalid applicatoin of " ++ show (pp tf) ++ " to " ++
unwords (map show xs)

View File

@ -1,10 +1,5 @@
{-# LANGUAGE PatternGuards #-} {-# LANGUAGE PatternGuards #-}
module Cryptol.TypeCheck.SimpleSolver module Cryptol.TypeCheck.SimpleSolver ( simplify , simplifyStep) where
( simplify
, simplifyStep
, tAdd, tSub, tMul, tDiv, tMod, tExp
, tMin, tMax, tWidth, tLenFromThen, tLenFromThenTo
) where
import Data.Map(Map) import Data.Map(Map)
import Control.Monad(msum) import Control.Monad(msum)
@ -20,9 +15,6 @@ import Cryptol.TypeCheck.Solver.Numeric.Interval(Interval)
import Cryptol.TypeCheck.Solver.Numeric(cryIsEqual, cryIsNotEqual, cryIsGeq) import Cryptol.TypeCheck.Solver.Numeric(cryIsEqual, cryIsNotEqual, cryIsGeq)
import Cryptol.TypeCheck.Solver.Class(solveArithInst,solveCmpInst) import Cryptol.TypeCheck.Solver.Class(solveArithInst,solveCmpInst)
type Ctxt = Map TVar Interval
simplify :: Ctxt -> Prop -> Prop simplify :: Ctxt -> Prop -> Prop
simplify ctxt p = simplify ctxt p =
case simplifyStep ctxt p of case simplifyStep ctxt p of
@ -49,196 +41,3 @@ simplifyStep ctxt prop =
_ -> Unsolved _ -> Unsolved
--------------------------------------------------------------------------------
-- Construction of type functions
-- Normal: constants to the left
tAdd :: Type -> Type -> Type
tAdd x y
| Just t <- tOp TCAdd (total (op2 nAdd)) [x,y] = t
| tIsInf x = tInf
| tIsInf y = tInf
| Just n <- tIsNum x = addK n y
| Just n <- tIsNum y = addK n x
| Just (n,x1) <- isSumK x = addK n (tAdd x1 y)
| Just (n,y1) <- isSumK y = addK n (tAdd x y1)
| otherwise = tf2 TCAdd x y
where
isSumK t = case tNoUser t of
TCon (TF TCAdd) [ l, r ] ->
do n <- tIsNum l
return (n, r)
_ -> Nothing
addK 0 t = t
addK n t | Just (m,b) <- isSumK t = tf2 TCAdd (tNum (n + m)) b
| otherwise = tf2 TCAdd (tNum n) t
tSub :: Type -> Type -> Type
tSub x y
| Just t <- tOp TCSub (op2 nSub) [x,y] = t
| tIsInf y = tBadNumber $ TCErrorMessage "Subtraction of `inf`."
| Just 0 <- yNum = x
| Just k <- yNum
, TCon (TF TCAdd) [a,b] <- tNoUser x
, Just n <- tIsNum a = case compare k n of
EQ -> b
LT -> tf2 TCAdd (tNum (n - k)) b
GT -> tSub b (tNum (k - n))
| otherwise = tf2 TCSub x y
where
yNum = tIsNum y
-- Normal: constants to the left
tMul :: Type -> Type -> Type
tMul x y
| Just t <- tOp TCMul (total (op2 nMul)) [x,y] = t
| Just n <- tIsNum x = mulK n y
| Just n <- tIsNum y = mulK n x
| otherwise = tf2 TCMul x y
where
mulK 0 _ = tNum (0 :: Int)
mulK 1 t = t
mulK n t | TCon (TF TCMul) [a,b] <- t'
, Just a' <- tIsNat' a = case a' of
Inf -> t
Nat m -> tf2 TCMul (tNum (n * m)) b
| TCon (TF TCDiv) [a,b] <- t'
, Just b' <- tIsNum b
-- XXX: similar for a = b * k?
, n == b' = tSub a (tMod a b)
| otherwise = tf2 TCMul (tNum n) t
where t' = tNoUser t
tDiv :: Type -> Type -> Type
tDiv x y
| Just t <- tOp TCDiv (op2 nDiv) [x,y] = t
| tIsInf x = tBadNumber $ TCErrorMessage "Division of `inf`."
| Just 0 <- tIsNum y = tBadNumber $ TCErrorMessage "Division by 0."
| otherwise = tf2 TCDiv x y
tMod :: Type -> Type -> Type
tMod x y
| Just t <- tOp TCMod (op2 nMod) [x,y] = t
| tIsInf x = tBadNumber $ TCErrorMessage "Modulus of `inf`."
| Just 0 <- tIsNum x = tBadNumber $ TCErrorMessage "Modulus by 0."
| otherwise = tf2 TCMod x y
tExp :: Type -> Type -> Type
tExp x y
| Just t <- tOp TCExp (total (op2 nExp)) [x,y] = t
| Just 0 <- tIsNum y = tNum (1 :: Int)
| TCon (TF TCExp) [a,b] <- tNoUser y = tExp x (tMul a b)
| otherwise = tf2 TCExp x y
-- Normal: constants to the left
tMin :: Type -> Type -> Type
tMin x y
| Just t <- tOp TCMin (total (op2 nMin)) [x,y] = t
| Just n <- tIsNat' x = minK n y
| Just n <- tIsNat' y = minK n x
| x == y = x
-- XXX: min (k + t) t -> t
| otherwise = tf2 TCMin x y
where
minK Inf t = t
minK (Nat 0) _ = tNum (0 :: Int)
minK (Nat k) t
| TCon (TF TCAdd) [a,b] <- t'
, Just n <- tIsNum a = if k <= n then tNum k
else tAdd a (tMin (tNum (k - n)) b)
| TCon (TF TCSub) [a,b] <- t'
, Just n <- tIsNum a =
if k >= n then t else tSub a (tMax (tNum (n - k)) b)
| TCon (TF TCMin) [a,b] <- t'
, Just n <- tIsNum a = tf2 TCMin (tNum (min k n)) b
| otherwise = tf2 TCMin (tNum k) t
where t' = tNoUser t
-- Normal: constants to the left
tMax :: Type -> Type -> Type
tMax x y
| Just t <- tOp TCMax (total (op2 nMax)) [x,y] = t
| Just n <- tIsNat' x = maxK n y
| Just n <- tIsNat' y = maxK n x
| otherwise = tf2 TCMax x y
where
maxK Inf _ = tInf
maxK (Nat 0) t = t
maxK (Nat k) t
| TCon (TF TCAdd) [a,b] <- t'
, Just n <- tIsNum a = if k <= n
then t
else tMax (tNum (k - n)) b
| TCon (TF TCSub) [a,b] <- t'
, Just n <- tIsNat' a =
case n of
Inf -> t
Nat m -> if k >= m then tNum k else tSub a (tMin (tNum (m - k)) b)
| TCon (TF TCMax) [a,b] <- t'
, Just n <- tIsNum a = tf2 TCMax (tNum (max k n)) b
| otherwise = tf2 TCMax (tNum k) t
where t' = tNoUser t
tWidth :: Type -> Type
tWidth x
| Just t <- tOp TCWidth (total (op1 nWidth)) [x] = t
| otherwise = tf1 TCWidth x
tLenFromThen :: Type -> Type -> Type -> Type
tLenFromThen x y z
| Just t <- tOp TCLenFromThen (op3 nLenFromThen) [x,y,z] = t
-- XXX: rules?
| otherwise = tf3 TCLenFromThen x y z
tLenFromThenTo :: Type -> Type -> Type -> Type
tLenFromThenTo x y z
| Just t <- tOp TCLenFromThen (op3 nLenFromThen) [x,y,z] = t
| otherwise = tf3 TCLenFromThenTo x y z
total :: ([Nat'] -> Nat') -> ([Nat'] -> Maybe Nat')
total f xs = Just (f xs)
op1 :: (a -> b) -> [a] -> b
op1 f ~[x] = f x
op2 :: (a -> a -> b) -> [a] -> b
op2 f ~[x,y] = f x y
op3 :: (a -> a -> a -> b) -> [a] -> b
op3 f ~[x,y,z] = f x y z
-- | Common checks: check for error, or simple full evaluation.
tOp :: TFun -> ([Nat'] -> Maybe Nat') -> [Type] -> Maybe Type
tOp tf f ts
| Just e <- msum (map tIsError ts) = Just (tBadNumber e)
| Just xs <- mapM tIsNat' ts =
Just $ case f xs of
Nothing -> tBadNumber (err xs)
Just n -> tNat' n
| otherwise = Nothing
where
err xs = TCErrorMessage $
"Invalid applicatoin of " ++ show (pp tf) ++ " to " ++
unwords (map show xs)

View File

@ -24,24 +24,29 @@ import Cryptol.TypeCheck.PP(pp)
import Cryptol.TypeCheck.AST import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Monad import Cryptol.TypeCheck.Monad
import Cryptol.TypeCheck.Subst import Cryptol.TypeCheck.Subst
(apSubst,fvs,singleSubst, isEmptySubst, (apSubst,fvs,singleSubst, isEmptySubst, substToList,
emptySubst,Subst,listSubst, (@@), Subst, emptySubst,Subst,listSubst, (@@), Subst,
apSubstMaybe, substBinds) apSubstMaybe, substBinds)
import qualified Cryptol.TypeCheck.SimpleSolver as Simplify import qualified Cryptol.TypeCheck.SimpleSolver as Simplify
import Cryptol.TypeCheck.Solver.Types import Cryptol.TypeCheck.Solver.Types
import Cryptol.TypeCheck.Solver.Selector(tryHasGoal) import Cryptol.TypeCheck.Solver.Selector(tryHasGoal)
import Cryptol.TypeCheck.SimpType(tMax)
import Cryptol.TypeCheck.Solver.Improve(improveProp,improveProps)
import Cryptol.TypeCheck.Solver.Numeric.Interval
import qualified Cryptol.TypeCheck.Solver.Numeric.AST as Num import qualified Cryptol.TypeCheck.Solver.Numeric.AST as Num
import qualified Cryptol.TypeCheck.Solver.Numeric.ImportExport as Num import qualified Cryptol.TypeCheck.Solver.Numeric.ImportExport as Num
import qualified Cryptol.TypeCheck.Solver.Numeric.SimplifyExpr as Num import qualified Cryptol.TypeCheck.Solver.Numeric.SimplifyExpr as Num
import qualified Cryptol.TypeCheck.Solver.CrySAT as Num import qualified Cryptol.TypeCheck.Solver.CrySAT as Num
import Cryptol.TypeCheck.Solver.CrySAT (debugBlock, DebugLog(..)) import Cryptol.TypeCheck.Solver.CrySAT (debugBlock, DebugLog(..))
import Cryptol.Utils.PP (text) import Cryptol.Utils.PP (text,vcat,nest, ($$), (<+>))
import Cryptol.Utils.Panic(panic) import Cryptol.Utils.Panic(panic)
import Cryptol.Utils.Misc(anyJust) import Cryptol.Utils.Misc(anyJust)
import Cryptol.Utils.Patterns(matchMaybe)
import Control.Monad (unless, guard) import Control.Monad (unless, guard, mzero)
import Control.Applicative ((<|>))
import Data.Either(partitionEithers) import Data.Either(partitionEithers)
import Data.Maybe(catMaybes, fromMaybe) import Data.Maybe(catMaybes, fromMaybe)
import Data.Map ( Map ) import Data.Map ( Map )
@ -80,6 +85,65 @@ wfType t =
--------------------------------------------------------------------------------
quickSolverIO :: Ctxt -> [Goal] -> IO (Either Goal (Subst,[Goal]))
quickSolverIO ctxt [] = return (Right (emptySubst, []))
quickSolverIO ctxt gs =
case quickSolver ctxt gs of
Left err ->
do msg (text "Contradiction:" <+> pp (goal err))
return (Left err)
Right (su,gs') ->
do msg (vcat (map (pp . goal) gs' ++ [pp su]))
return (Right (su,gs'))
where
shAsmps = case [ pp x <+> text "in" <+> ppInterval i |
(x,i) <- Map.toList ctxt ] of
[] -> text ""
xs -> text "ASMPS:" $$ nest 2 (vcat xs $$ text "===")
msg d = return () {-putStrLn $ show (
text "quickSolver:" $$ nest 2 (vcat
[ shAsmps
, vcat (map (pp.goal) gs)
, text "==>"
, d
]))-}
quickSolver :: Ctxt -> {- ^ Facts we can know -}
[Goal] -> {- ^ Need to solve these -}
Either Goal (Subst,[Goal])
-- ^ Left: contradiciting goals,
-- Right: inferred types, unsolved goals.
quickSolver ctxt gs0 = go emptySubst [] gs0
where
go su [] [] = Right (su,[])
go su unsolved [] =
case matchMaybe (findImprovement unsolved) of
Nothing -> Right (su,unsolved)
Just (newSu, subs) -> go (newSu @@ su) [] (subs ++ apSubst su unsolved)
go su unsolved (g : gs) =
case Simplify.simplifyStep ctxt (goal g) of
Unsolvable _ -> Left g
Unsolved -> go su (g : unsolved) gs
SolvedIf subs ->
let cvt x = g { goal = x }
in go su unsolved (map cvt subs ++ gs)
-- Probably better to find more than one.
findImprovement [] = mzero
findImprovement (g : gs) =
do (su,ps) <- improveProp False ctxt (goal g)
return (su, [ g { goal = p } | p <- ps ])
<|> findImprovement gs
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
simplifyAllConstraints :: InferM () simplifyAllConstraints :: InferM ()
@ -89,11 +153,8 @@ simplifyAllConstraints =
case gs of case gs of
[] -> return () [] -> return ()
_ -> _ ->
do -- r <- curRange do solver <- getSolver
io $ putStrLn $ "simplifyAllConstraints " ++ show (length gs) (mb,su) <- io (simpGoals' solver Map.empty gs)
io $ putStrLn $ unlines [ " " ++ show (pp (goal g), pp (goalSource g)) | g <- gs ]
solver <- getSolver
(mb,su) <- io (simpGoals' solver gs)
extendSubst su extendSubst su
case mb of case mb of
Right gs1 -> addGoals gs1 Right gs1 -> addGoals gs1
@ -120,7 +181,35 @@ proveImplicationIO :: Num.Solver
-> [Goal] -- ^ Collected constraints -> [Goal] -- ^ Collected constraints
-> IO (Either Error [Warning], Subst) -> IO (Either Error [Warning], Subst)
proveImplicationIO _ _ _ _ [] [] = return (Right [], emptySubst) proveImplicationIO _ _ _ _ [] [] = return (Right [], emptySubst)
proveImplicationIO s lname varsInEnv as ps gs = proveImplicationIO s f vs ps asmps0 gs0 =
do let ctxt = assumptionIntervals Map.empty asmps
res <- quickSolverIO ctxt gs
case res of
Left err -> return (Left (UnsolvedGoal True err), emptySubst)
Right (su,[]) -> return (Right [], su)
Right (su,gs1) -> proveImplicationIO' s f vs ps asmps gs1
where
(asmps,gs) =
case matchMaybe (improveProps True Map.empty asmps0) of
Nothing -> (asmps0,gs0)
Just (newSu,newAsmps) ->
( [ TVar x =#= t | (x,t) <- substToList newSu ]
++ newAsmps
, [ g { goal = apSubst newSu (goal g) } | g <- gs0 ]
)
proveImplicationIO' :: Num.Solver
-> Name -- ^ Checking this function
-> Set TVar -- ^ These appear in the env., and we should
-- not try to default the
-> [TParam] -- ^ Type parameters
-> [Prop] -- ^ Assumed constraint
-> [Goal] -- ^ Collected constraints
-> IO (Either Error [Warning], Subst)
proveImplicationIO' s lname varsInEnv as ps gs =
debugBlock s "proveImplicationIO" $ debugBlock s "proveImplicationIO" $
do debugBlock s "assumes" (debugLog s ps) do debugBlock s "assumes" (debugLog s ps)
@ -154,7 +243,8 @@ proveImplicationIO s lname varsInEnv as ps gs =
let gs1 = filter ((`notElem` ps) . goal) gs0 let gs1 = filter ((`notElem` ps) . goal) gs0
debugLog s "3. ---------------------" debugLog s "3. ---------------------"
(mb,su1) <- simpGoals' s (scs ++ gs1) let ctxt = assumptionIntervals Map.empty ps
(mb,su1) <- simpGoals' s ctxt (scs ++ gs1)
case mb of case mb of
Left badGs -> reportUnsolved badGs (su1 @@ su) Left badGs -> reportUnsolved badGs (su1 @@ su)
@ -219,15 +309,22 @@ The plan:
9. Goto 3 9. Goto 3
-} -}
simpGoals' :: Num.Solver -> [Goal] -> IO (Either [Goal] [Goal], Subst) simpGoals' :: Num.Solver -> Ctxt -> [Goal] -> IO (Either [Goal] [Goal], Subst)
simpGoals' s gs0 = go emptySubst [] (wellFormed gs0 ++ gs0) simpGoals' s asmps gs =
do res <- quickSolverIO asmps gs
case res of
Left err -> return (Left [err], emptySubst)
Right (su,gs) -> return (Right gs, su)
simpGoals' s asmps gs0 = go emptySubst [] (wellFormed gs0 ++ gs0)
where where
-- Assumes that the well-formed constraints are themselves well-formed. -- Assumes that the well-formed constraints are themselves well-formed.
wellFormed gs = [ g { goal = p } | g <- gs, p <- wfType (goal g) ] wellFormed gs = [ g { goal = p } | g <- gs, p <- wfType (goal g) ]
go su old [] = return (Right old, su) go su old [] = return (Right old, su)
go su old gs = go su old gs =
do res <- solveConstraints s old gs do res <- solveConstraints s asmps old gs
case res of case res of
Left err -> return (Left err, su) Left err -> return (Left err, su)
Right gs2 -> Right gs2 ->
@ -264,7 +361,18 @@ want to use the fact that `x >= 1` to simplify `x >= 1` to true.
assumptionIntervals :: Ctxt -> [Prop] -> Ctxt
assumptionIntervals as ps =
case computePropIntervals as ps of
NoChange -> as
InvalidInterval {} -> as -- XXX: say something
NewIntervals bs -> Map.union bs as
solveConstraints :: Num.Solver -> solveConstraints :: Num.Solver ->
Ctxt ->
[Goal] {- We may use these, but don't try to solve, [Goal] {- We may use these, but don't try to solve,
we already tried and failed. -} -> we already tried and failed. -} ->
[Goal] {- Need to solve these -} -> [Goal] {- Need to solve these -} ->
@ -272,22 +380,25 @@ solveConstraints :: Num.Solver ->
-- ^ Left: contradiciting goals, -- ^ Left: contradiciting goals,
-- Right: goals that were not solved, or sub-goals -- Right: goals that were not solved, or sub-goals
-- for solved goals. Does not include "old" -- for solved goals. Does not include "old"
solveConstraints s otherGs gs0 = solveConstraints s asmps otherGs gs0 =
debugBlock s "Solving constraints" $ go [] gs0 debugBlock s "Solving constraints" $ go ctxt0 [] gs0
where where
go unsolved [] = ctxt0 = assumptionIntervals asmps (map goal otherGs)
go ctxt unsolved [] =
do let (cs,nums) = partitionEithers (map Num.numericRight unsolved) do let (cs,nums) = partitionEithers (map Num.numericRight unsolved)
nums' <- solveNumerics s otherNumerics nums nums' <- solveNumerics s otherNumerics nums
return (Right (cs ++ nums')) return (Right (cs ++ nums'))
go unsolved (g : gs) = go ctxt unsolved (g : gs) =
case Simplify.simplifyStep Map.empty (goal g) of case Simplify.simplifyStep ctxt (goal g) of
Unsolvable _ -> return (Left [g]) Unsolvable _ -> return (Left [g])
Unsolved -> go (g : unsolved) gs Unsolved -> go ctxt (g : unsolved) gs
SolvedIf subs -> SolvedIf subs ->
let cvt x = g { goal = x } let cvt x = g { goal = x }
in go unsolved (map cvt subs ++ gs) in go ctxt unsolved (map cvt subs ++ gs)
otherNumerics = [ g | Right g <- map Num.numericRight otherGs ] otherNumerics = [ g | Right g <- map Num.numericRight otherGs ]
@ -299,6 +410,7 @@ solveNumerics :: Num.Solver ->
[(Goal,Num.Prop)] {- ^ Consult these -} -> [(Goal,Num.Prop)] {- ^ Consult these -} ->
[(Goal,Num.Prop)] {- ^ Solve these -} -> [(Goal,Num.Prop)] {- ^ Solve these -} ->
IO [Goal] IO [Goal]
solveNumerics _ _ [] = return []
solveNumerics s consultGs solveGs = solveNumerics s consultGs solveGs =
Num.withScope s $ Num.withScope s $
do _ <- Num.assumeProps s (map (goal . fst) consultGs) do _ <- Num.assumeProps s (map (goal . fst) consultGs)
@ -439,7 +551,7 @@ improveByDefaultingWith s as ps =
su = listSubst defs su = listSubst defs
-- Do this to simplify the instantiated "fin" constraints. -- Do this to simplify the instantiated "fin" constraints.
(mb,su1) <- simpGoals' s (newOthers ++ others ++ apSubst su fins) (mb,su1) <- simpGoals' s Map.empty (newOthers ++ others ++ apSubst su fins)
case mb of case mb of
Right gs1 -> Right gs1 ->
let warn (x,t) = let warn (x,t) =
@ -536,7 +648,7 @@ defaultReplExpr so e s =
mbSubst <- tryGetModel so params (sProps s) mbSubst <- tryGetModel so params (sProps s)
case mbSubst of case mbSubst of
Just su -> Just su ->
do (res,su1) <- simpGoals' so (map (makeGoal su) (sProps s)) do (res,su1) <- simpGoals' so Map.empty (map (makeGoal su) (sProps s))
return $ return $
case res of case res of
Right [] | isEmptySubst su1 -> Right [] | isEmptySubst su1 ->
@ -600,35 +712,3 @@ simpTypeMaybe ty =
--------------------------------------------------------------------------------
_testSimpGoals :: IO ()
_testSimpGoals = Num.withSolver cfg $ \s ->
do _ <- Num.assumeProps s asmps
_mbImps <- Num.check s
(mb,_) <- simpGoals' s gs
case mb of
Right _ -> debugLog s "End of test"
Left _ -> debugLog s "Impossible"
where
cfg = SolverConfig { solverPath = "z3"
, solverArgs = [ "-smt2", "-in" ]
, solverVerbose = 1
}
asmps = []
gs = map fakeGoal [ tv 0 =#= tMin (num 10) (tv 1)
, tv 1 =#= num 10
]
fakeGoal p = Goal { goalSource = undefined, goalRange = undefined, goal = p }
tv n = TVar (TVFree n KNum Set.empty (text "test var"))
_btv n = TVar (TVBound n KNum)
num x = tNum (x :: Int)

View File

@ -0,0 +1,207 @@
-- | Look for opportunity to solve goals by instantiating variables.
module Cryptol.TypeCheck.Solver.Improve where
import qualified Data.Set as Set
import Control.Applicative
import Control.Monad
import Cryptol.Utils.Patterns
import Cryptol.TypeCheck.Type
import Cryptol.TypeCheck.SimpType as Mk
import Cryptol.TypeCheck.Solver.Types
import Cryptol.TypeCheck.Solver.Numeric.Interval
import Cryptol.TypeCheck.TypePat
import Cryptol.TypeCheck.Subst
improveProps :: Bool -> Ctxt -> [Prop] -> Match (Subst,[Prop])
improveProps impSkol ctxt ps0 = loop emptySubst ps0
where
loop su props = case go emptySubst [] props of
(newSu,newProps)
| isEmptySubst newSu ->
if isEmptySubst su then mzero else return (su,props)
| otherwise -> loop (newSu @@ su) newProps
go su subs [] = (su,subs)
go su subs (p : ps) =
case matchMaybe (improveProp impSkol ctxt p) of
Nothing -> go su (p:subs) ps
Just (suNew,psNew) -> go (suNew @@ su) (psNew ++ subs) ps
improveProp :: Bool -> Ctxt -> Prop -> Match (Subst,[Prop])
improveProp impSkol ctxt prop =
improveEq impSkol ctxt prop
-- XXX: others
improveEq :: Bool -> Ctxt -> Prop -> Match (Subst,[Prop])
improveEq impSkol fins prop =
do (lhs,rhs) <- (|=|) prop
rewrite lhs rhs <|> rewrite rhs lhs
where
rewrite this other =
do x <- aTVar this
guard (considerVar x && x `Set.notMember` fvs other)
return (singleSubst x other, [])
<|>
do (v,s) <- isSum this
guard (v `Set.notMember` fvs other)
return (singleSubst v (Mk.tSub other s), [ other >== s ])
isSum t = do (v,s) <- matches t (anAdd, aTVar, __)
valid v s
<|> do (s,v) <- matches t (anAdd, __, aTVar)
valid v s
valid v s = do let i = typeInterval fins s
guard (considerVar v && v `Set.notMember` fvs s && iIsFin i)
return (v,s)
considerVar x = impSkol || isFreeTV x
--------------------------------------------------------------------------------
-- XXX
{-
-- | When given an equality constraint, attempt to rewrite it to the form `?x =
-- ...`, by moving all occurrences of `?x` to the LHS, and any other variables
-- to the RHS. This will only work when there's only one unification variable
-- present in the prop.
tryRewrteEqAsSubst :: Ctxt -> Type -> Type -> Maybe (TVar,Type)
tryRewrteEqAsSubst fins t1 t2 =
do let vars = Set.toList (Set.filter isFreeTV (fvs (t1,t2)))
listToMaybe $ sortBy (flip compare `on` rank)
$ catMaybes [ tryRewriteEq fins var t1 t2 | var <- vars ]
-- | Rank a rewrite, favoring expressions that have fewer subtractions than
-- additions.
rank :: (TVar,Type) -> Int
rank (_,ty) = go ty
where
go (TCon (TF TCAdd) ts) = sum (map go ts) + 1
go (TCon (TF TCSub) ts) = sum (map go ts) - 1
go (TCon (TF TCMul) ts) = sum (map go ts) + 1
go (TCon (TF TCDiv) ts) = sum (map go ts) - 1
go (TCon _ ts) = sum (map go ts)
go _ = 0
-- | Rewrite an equation with respect to a unification variable ?x, into the
-- form `?x = t`. There are two interesting cases to consider (four with
-- symmetry):
--
-- * ?x = ty
-- * expr containing ?x = expr
--
-- In the first case, we just return the type variable and the type, but in the
-- second we try to rewrite the equation until it's in the form of the first
-- case.
tryRewriteEq :: Map TVar Interval -> TVar -> Type -> Type -> Maybe (TVar,Type)
tryRewriteEq fins uvar l r =
msum [ do guard (uvarTy == l && uvar `Set.notMember` rfvs)
return (uvar, r)
, do guard (uvarTy == r && uvar `Set.notMember` lfvs)
return (uvar, l)
, do guard (uvar `Set.notMember` rfvs)
ty <- rewriteLHS fins uvar l r
return (uvar,ty)
, do guard (uvar `Set.notMember` lfvs)
ty <- rewriteLHS fins uvar r l
return (uvar,ty)
]
where
uvarTy = TVar uvar
lfvs = fvs l
rfvs = fvs r
-- | Check that a type contains only finite type variables.
allFin :: Map TVar Interval -> Type -> Bool
allFin ints ty = iIsFin (typeInterval ints ty)
-- | Rewrite an equality until the LHS is just `uvar`. Return the rewritten RHS.
--
-- There are a few interesting cases when rewriting the equality:
--
-- A o B = R when `uvar` is only present in A
-- A o B = R when `uvar` is only present in B
--
-- In the first case, as we only consider addition and subtraction, the
-- rewriting will continue on the left, after moving the `B` side to the RHS of
-- the equation. In the second case, if the operation is addition, the `A` side
-- will be moved to the RHS, with rewriting continuing in `B`. However, in the
-- case of subtraction, the `B` side is moved to the RHS, and rewriting
-- continues on the RHS instead.
--
-- In both cases, if the operation is addition, rewriting will only continue if
-- the operand being moved to the RHS is known to be finite. If this check was
-- not done, we would end up violating the well-definedness condition for
-- subtraction (for a, b: well defined (a - b) iff fin b).
rewriteLHS :: Map TVar Interval -> TVar -> Type -> Type -> Maybe Type
rewriteLHS fins uvar = go
where
go (TVar tv) rhs | tv == uvar = return rhs
go (TCon (TF tf) [x,y]) rhs =
do let xfvs = fvs x
yfvs = fvs y
inX = Set.member uvar xfvs
inY = Set.member uvar yfvs
if | inX && inY -> mzero
| inX -> balanceR x tf y rhs
| inY -> balanceL x tf y rhs
| otherwise -> mzero
-- discard type synonyms, the rewriting will make them no longer apply
go (TUser _ _ l) rhs =
go l rhs
-- records won't work here.
go _ _ =
mzero
-- invert the type function to balance the equation, when the variable occurs
-- on the LHS of the expression `x tf y`
balanceR x TCAdd y rhs = do guardFin y
go x (tSub rhs y)
balanceR x TCSub y rhs = go x (tAdd rhs y)
balanceR _ _ _ _ = mzero
-- invert the type function to balance the equation, when the variable occurs
-- on the RHS of the expression `x tf y`
balanceL x TCAdd y rhs = do guardFin y
go y (tSub rhs x)
balanceL x TCSub y rhs = go (tAdd rhs y) x
balanceL _ _ _ _ = mzero
-- guard that the type is finite
--
-- XXX this ignores things like `min x inf` where x is finite, and just
-- assumes that it won't work.
guardFin ty = guard (allFin fins ty)
-}

View File

@ -1,267 +1,177 @@
{-# LANGUAGE Trustworthy, PatternGuards, MultiWayIf #-} {-# LANGUAGE Safe, PatternGuards, MultiWayIf #-}
module Cryptol.TypeCheck.Solver.Numeric module Cryptol.TypeCheck.Solver.Numeric
( cryIsEqual, cryIsNotEqual, cryIsGeq ( cryIsEqual, cryIsNotEqual, cryIsGeq
) where ) where
import Control.Monad (msum,guard,mzero) import Control.Applicative(Alternative(..))
import Data.Function (on) import Control.Monad (guard,mzero)
import Data.List (sortBy)
import Data.Maybe (catMaybes,listToMaybe)
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.Set as Set
import Cryptol.Utils.Patterns
import Cryptol.TypeCheck.PP import Cryptol.TypeCheck.PP
import Cryptol.TypeCheck.Type import Cryptol.TypeCheck.Type
import Cryptol.TypeCheck.TypePat
import Cryptol.TypeCheck.Solver.Types import Cryptol.TypeCheck.Solver.Types
import Cryptol.TypeCheck.Solver.InfNat import Cryptol.TypeCheck.Solver.InfNat
import Cryptol.TypeCheck.Solver.Numeric.Interval import Cryptol.TypeCheck.Solver.Numeric.Interval
import Cryptol.TypeCheck.SimpType
import Debug.Trace
cryIsEqual :: Map TVar Interval -> Type -> Type -> Solved cryIsEqual :: Ctxt -> Type -> Type -> Solved
cryIsEqual fin t1 t2 = cryIsEqual _ t1 t2 =
solveOpts matchDefault Unsolved $
[ pBin PEqual (==) fin t1 t2 (pBin PEqual (==) t1 t2)
, tIsNat' t1 `matchThen` \n -> tryEqK n t2 <|> (aNat' t1 >>= tryEqK t2)
, tIsNat' t2 `matchThen` \n -> tryEqK n t1 <|> (aNat' t2 >>= tryEqK t1)
, tIsVar t1 `matchThen` \tv -> tryEqInf tv t2 <|> (aTVar t1 >>= tryEqVar t2)
, tIsVar t2 `matchThen` \tv -> tryEqInf tv t1 <|> (aTVar t2 >>= tryEqVar t1)
, guarded (t1 == t2) $ SolvedIf [] <|> ( guard (t1 == t2) >> return (SolvedIf []))
-- x = min (K + x) y --> x = y
]
{-
case
Unsolved
| Just x <- tIsVar t1, isFreeTV x -> Unsolved
| Just n <- tIsNat' t1 -> tryEqK n t2
| Just n <- tIsNat' t2 -> tryEqK n t1
| Just (x,t) <- tryRewrteEqAsSubst fin t1 t2 ->
let new = show (pp x) ++ " == " ++ show (pp t)
in
trace ("Rewrote: " ++ sh ++ " -> " ++ new)
$ SolvedIf [ TCon (PC PEqual) [TVar x,t] ]
Unsolved -> trace ("Failed to rewrite eq: " ++ sh) Unsolved
x -> x
where
sh = show (pp t1) ++ " == " ++ show (pp t2)
-}
cryIsNotEqual :: Map TVar Interval -> Type -> Type -> Solved cryIsNotEqual :: Ctxt -> Type -> Type -> Solved
cryIsNotEqual = pBin PNeq (/=) cryIsNotEqual _i t1 t2 = matchDefault Unsolved (pBin PNeq (/=) t1 t2)
cryIsGeq :: Ctxt -> Type -> Type -> Solved
cryIsGeq i t1 t2 =
matchDefault Unsolved $
(pBin PGeq (>=) t1 t2)
<|> (aNat' t1 >>= tryGeqKThan i t2)
<|> (aTVar t2 >>= tryGeqThanVar i t1)
<|> tryGeqThanSub i t1 t2
<|> (geqByInterval i t1 t2)
<|> (guard (t1 == t2) >> return (SolvedIf []))
cryIsGeq :: Map TVar Interval -> Type -> Type -> Solved
cryIsGeq = pBin PGeq (>=)
-- XXX: max a 10 >= 2 --> True -- XXX: max a 10 >= 2 --> True
-- XXX: max a 2 >= 10 --> a >= 10 -- XXX: max a 2 >= 10 --> a >= 10
pBin :: PC -> (Nat' -> Nat' -> Bool) -> Map TVar Interval -> pBin :: PC -> (Nat' -> Nat' -> Bool) -> Type -> Type -> Match Solved
Type -> Type -> Solved pBin tf p t1 t2 =
pBin tf p _i t1 t2 Unsolvable <$> anError t1
| Just e <- tIsError t1 = Unsolvable e <|> Unsolvable <$> anError t2
| Just e <- tIsError t2 = Unsolvable e <|> (do x <- aNat' t1
| Just x <- tIsNat' t1 y <- aNat' t2
, Just y <- tIsNat' t2 = return $ if p x y
if p x y then SolvedIf []
then SolvedIf [] else Unsolvable $ TCErrorMessage
else Unsolvable $ TCErrorMessage
$ "Predicate " ++ show (pp tf) ++ " does not hold for " $ "Predicate " ++ show (pp tf) ++ " does not hold for "
++ show x ++ " and " ++ show y ++ show x ++ " and " ++ show y)
pBin _ _ _ _ _ = Unsolved
--------------------------------------------------------------------------------
-- GEQ
tryGeqKThan :: Ctxt -> Type -> Nat' -> Match Solved
tryGeqKThan _ _ Inf = return (SolvedIf [])
tryGeqKThan _ ty (Nat n) =
do (a,b) <- aMul ty
m <- aNat' a
return $ SolvedIf
$ case m of
Inf -> [ b =#= tZero ]
Nat 0 -> []
Nat k -> [ tNum (div n k) >== b ]
tryGeqThanSub :: Ctxt -> Type -> Type -> Match Solved
tryGeqThanSub _ x y =
do (a,_) <- (|-|) y
guard (x == a)
return (SolvedIf [])
tryGeqThanVar :: Ctxt -> Type -> TVar -> Match Solved
tryGeqThanVar _ctxt ty x =
do (a,b) <- anAdd ty
let check y = do x' <- aTVar y
guard (x == x')
return (SolvedIf [])
check a <|> check b
geqByInterval :: Ctxt -> Type -> Type -> Match Solved
geqByInterval ctxt x y =
let ix = typeInterval ctxt x
iy = typeInterval ctxt y
in case (iLower ix, iUpper iy) of
(l,Just n) | l >= n -> return (SolvedIf [])
_ -> mzero
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
tryEqInf :: TVar -> Type -> Solved
tryEqInf tv ty =
case tNoUser ty of
TCon (TF TCAdd) [a,b]
| Just n <- tIsNum a, n >= 1
, Just v <- tIsVar b, tv == v -> SolvedIf [ TVar tv =#= tInf ]
_ -> Unsolved
tryEqK :: Nat' -> Type -> Solved tryEqVar :: Type -> TVar -> Match Solved
tryEqK lk ty = tryEqVar ty x =
case tNoUser ty of
TCon (TF f) [ a, b ] | Just rk <- tIsNat' a ->
case f of
TCAdd -> -- x = K + x --> x = inf
case (lk,rk) of (do (k,tv) <- matches ty (anAdd, aNat, aTVar)
(_,Inf) -> Unsolved -- shouldn't happen, as `inf + x ` inf` guard (tv == x && k >= 1)
(Inf, Nat _) -> SolvedIf [ b =#= tInf ]
(Nat lk', Nat rk')
| lk' >= rk' -> SolvedIf [ b =#= tNum (lk' - rk') ]
| otherwise -> Unsolvable
$ TCErrorMessage
$ "Adding " ++ show rk' ++ " will always exceed "
++ show lk'
TCMul -> return $ SolvedIf [ TVar x =#= tInf ]
case (lk,rk) of )
(Inf,Inf) -> SolvedIf [ b >== tOne ] <|>
(Inf,Nat _) -> SolvedIf [ b =#= tInf ]
(Nat 0, Inf) -> SolvedIf [ b =#= tZero ]
(Nat k, Inf) -> Unsolvable
$ TCErrorMessage $ show k ++ " /= inf * anything"
(Nat lk', Nat rk')
| rk' == 0 -> Unsolved --- shouldn't happen, as `0 * x = x`
| (q,0) <- divMod lk' rk' -> SolvedIf [ b =#= tNum q ]
| otherwise -> Unsolvable
$ TCErrorMessage
$ show lk ++ " /= " ++ show rk ++ " * anything"
-- XXX: Min, Max, etx -- x = min (K + x) y --> x = y
-- 2 = min (10,y) --> y = 2 (do (l,r) <- aMin ty
-- 2 = min (2,y) --> y >= 2 let check this other =
-- 10 = min (2,y) --> impossible do (k,x') <- matches this (anAdd, aNat', aTVar)
_ -> Unsolved guard (x == x' && k >= Nat 1)
return $ SolvedIf [ TVar x =#= other ]
check l r <|> check r l
_ -> Unsolved )
<|>
-- x = K + min a x
(do (k,(l,r)) <- matches ty (anAdd, aNat, aMin)
guard (k >= 1)
let check a b = do x' <- aTVar a
guard (x' == x)
return (SolvedIf [ TVar x =#= tAdd (tNum k) b ])
check l r <|> check r l
)
-- | When given an equality constraint, attempt to rewrite it to the form `?x =
-- ...`, by moving all occurrences of `?x` to the LHS, and any other variables
-- to the RHS. This will only work when there's only one unification variable
-- present in the prop.
tryRewrteEqAsSubst :: Map TVar Interval -> Type -> Type -> Maybe (TVar,Type)
tryRewrteEqAsSubst fins t1 t2 =
do let vars = Set.toList (Set.filter isFreeTV (fvs (t1,t2)))
listToMaybe $ sortBy (flip compare `on` rank)
$ catMaybes [ tryRewriteEq fins var t1 t2 | var <- vars ]
-- | Rank a rewrite, favoring expressions that have fewer subtractions than
-- additions.
rank :: (TVar,Type) -> Int
rank (_,ty) = go ty
where
go (TCon (TF TCAdd) ts) = sum (map go ts) + 1 -- e.g., 10 = t
go (TCon (TF TCSub) ts) = sum (map go ts) - 1 tryEqK :: Type -> Nat' -> Match Solved
go (TCon (TF TCMul) ts) = sum (map go ts) + 1 tryEqK ty lk =
go (TCon (TF TCDiv) ts) = sum (map go ts) - 1
go (TCon _ ts) = sum (map go ts) do (rk, b) <- matches ty (anAdd, aNat', __)
go _ = 0 return $
case nSub lk rk of
-- NOTE: (Inf - Inf) shouldn't be possible
Nothing -> Unsolvable
$ TCErrorMessage
$ "Adding " ++ show rk ++ " will always exceed "
++ show lk
Just r -> SolvedIf [ b =#= tNat' r ]
<|>
do (rk, b) <- matches ty (aMul, aNat', __)
return $
case (lk,rk) of
(Inf,Inf) -> SolvedIf [ b >== tOne ]
(Inf,Nat _) -> SolvedIf [ b =#= tInf ]
(Nat 0, Inf) -> SolvedIf [ b =#= tZero ]
(Nat k, Inf) -> Unsolvable
$ TCErrorMessage
$ show k ++ " /= inf * anything"
(Nat lk', Nat rk')
| rk' == 0 -> SolvedIf [ tNat' lk =#= tZero ]
-- shouldn't happen, as `0 * x = x`
| (q,0) <- divMod lk' rk' -> SolvedIf [ b =#= tNum q ]
| otherwise ->
Unsolvable
$ TCErrorMessage
$ show lk ++ " /= " ++ show rk ++ " * anything"
-- XXX: Min, Max, etx
-- 2 = min (10,y) --> y = 2
-- 2 = min (2,y) --> y >= 2
-- 10 = min (2,y) --> impossible
-- | Rewrite an equation with respect to a unification variable ?x, into the
-- form `?x = t`. There are two interesting cases to consider (four with
-- symmetry):
--
-- * ?x = ty
-- * expr containing ?x = expr
--
-- In the first case, we just return the type variable and the type, but in the
-- second we try to rewrite the equation until it's in the form of the first
-- case.
tryRewriteEq :: Map TVar Interval -> TVar -> Type -> Type -> Maybe (TVar,Type)
tryRewriteEq fins uvar l r =
msum [ do guard (uvarTy == l && uvar `Set.notMember` rfvs)
return (uvar, r)
, do guard (uvarTy == r && uvar `Set.notMember` lfvs)
return (uvar, l)
, do guard (uvar `Set.notMember` rfvs)
ty <- rewriteLHS fins uvar l r
return (uvar,ty)
, do guard (uvar `Set.notMember` lfvs)
ty <- rewriteLHS fins uvar r l
return (uvar,ty)
]
where
uvarTy = TVar uvar
lfvs = fvs l
rfvs = fvs r
-- | Check that a type contains only finite type variables.
allFin :: Map TVar Interval -> Type -> Bool
allFin ints ty = iIsFin (typeInterval ints ty)
-- | Rewrite an equality until the LHS is just `uvar`. Return the rewritten RHS.
--
-- There are a few interesting cases when rewriting the equality:
--
-- A o B = R when `uvar` is only present in A
-- A o B = R when `uvar` is only present in B
--
-- In the first case, as we only consider addition and subtraction, the
-- rewriting will continue on the left, after moving the `B` side to the RHS of
-- the equation. In the second case, if the operation is addition, the `A` side
-- will be moved to the RHS, with rewriting continuing in `B`. However, in the
-- case of subtraction, the `B` side is moved to the RHS, and rewriting
-- continues on the RHS instead.
--
-- In both cases, if the operation is addition, rewriting will only continue if
-- the operand being moved to the RHS is known to be finite. If this check was
-- not done, we would end up violating the well-definedness condition for
-- subtraction (for a, b: well defined (a - b) iff fin b).
rewriteLHS :: Map TVar Interval -> TVar -> Type -> Type -> Maybe Type
rewriteLHS fins uvar = go
where
go (TVar tv) rhs | tv == uvar = return rhs
go (TCon (TF tf) [x,y]) rhs =
do let xfvs = fvs x
yfvs = fvs y
inX = Set.member uvar xfvs
inY = Set.member uvar yfvs
if | inX && inY -> mzero
| inX -> balanceR x tf y rhs
| inY -> balanceL x tf y rhs
| otherwise -> mzero
-- discard type synonyms, the rewriting will make them no longer apply
go (TUser _ _ l) rhs =
go l rhs
-- records won't work here.
go _ _ =
mzero
-- invert the type function to balance the equation, when the variable occurs
-- on the LHS of the expression `x tf y`
balanceR x TCAdd y rhs = do guardFin y
go x (tSub rhs y)
balanceR x TCSub y rhs = go x (tAdd rhs y)
balanceR _ _ _ _ = mzero
-- invert the type function to balance the equation, when the variable occurs
-- on the RHS of the expression `x tf y`
balanceL x TCAdd y rhs = do guardFin y
go y (tSub rhs x)
balanceL x TCSub y rhs = go (tAdd rhs y) x
balanceL _ _ _ _ = mzero
-- guard that the type is finite
--
-- XXX this ignores things like `min x inf` where x is finite, and just
-- assumes that it won't work.
guardFin ty = guard (allFin fins ty)

View File

@ -20,7 +20,7 @@ module Cryptol.TypeCheck.Solver.Numeric.ImportExport
import Cryptol.TypeCheck.Solver.Numeric.AST import Cryptol.TypeCheck.Solver.Numeric.AST
import qualified Cryptol.TypeCheck.AST as Cry import qualified Cryptol.TypeCheck.AST as Cry
import qualified Cryptol.TypeCheck.SimpleSolver as SCry import qualified Cryptol.TypeCheck.SimpType as SCry
import MonadLib import MonadLib
exportProp :: Cry.Prop -> Maybe Prop exportProp :: Cry.Prop -> Maybe Prop

View File

@ -1,13 +1,19 @@
module Cryptol.TypeCheck.Solver.Types where module Cryptol.TypeCheck.Solver.Types where
import Data.Map(Map)
import Cryptol.TypeCheck.Type import Cryptol.TypeCheck.Type
import Cryptol.TypeCheck.PP import Cryptol.TypeCheck.PP
import Cryptol.TypeCheck.Solver.Numeric.Interval
type Ctxt = Map TVar Interval
data Solved = SolvedIf [Prop] -- ^ Solved, assuming the sub-goals. data Solved = SolvedIf [Prop] -- ^ Solved, assuming the sub-goals.
| Unsolved -- ^ We could not solve the goal. | Unsolved -- ^ We could not solve the goal.
| Unsolvable TCErrorMessage -- ^ The goal can never be solved. | Unsolvable TCErrorMessage -- ^ The goal can never be solved.
deriving (Show) deriving (Show)
elseTry :: Solved -> Solved -> Solved elseTry :: Solved -> Solved -> Solved
Unsolved `elseTry` x = x Unsolved `elseTry` x = x
x `elseTry` _ = x x `elseTry` _ = x

View File

@ -8,8 +8,8 @@
module Cryptol.TypeCheck.Solver.Utils where module Cryptol.TypeCheck.Solver.Utils where
import Cryptol.TypeCheck.AST hiding (tAdd,tMul) import Cryptol.TypeCheck.AST hiding (tMul)
import Cryptol.TypeCheck.SimpleSolver(tAdd,tMul) import Cryptol.TypeCheck.SimpType(tAdd,tMul)
import Control.Monad(mplus,guard) import Control.Monad(mplus,guard)
import Data.Maybe(listToMaybe) import Data.Maybe(listToMaybe)

View File

@ -25,6 +25,7 @@ module Cryptol.TypeCheck.Subst
, apSubstTypeMapKeys , apSubstTypeMapKeys
, substBinds , substBinds
, applySubstToVar , applySubstToVar
, substToList
) where ) where
import Data.Maybe import Data.Maybe
@ -37,6 +38,7 @@ import qualified Data.Set as Set
import Cryptol.TypeCheck.AST import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.PP import Cryptol.TypeCheck.PP
import Cryptol.TypeCheck.TypeMap import Cryptol.TypeCheck.TypeMap
import qualified Cryptol.TypeCheck.SimpType as Simp
import qualified Cryptol.TypeCheck.SimpleSolver as Simp import qualified Cryptol.TypeCheck.SimpleSolver as Simp
import Cryptol.Utils.Panic(panic) import Cryptol.Utils.Panic(panic)
import Cryptol.Utils.Misc(anyJust) import Cryptol.Utils.Misc(anyJust)
@ -85,6 +87,11 @@ substBinds su
| suDefaulting su = Set.empty | suDefaulting su = Set.empty
| otherwise = Map.keysSet $ suMap su | otherwise = Map.keysSet $ suMap su
substToList :: Subst -> [(TVar,Type)]
substToList s
| suDefaulting s = panic "substToList" ["Defaulting substitution."]
| otherwise = Map.toList (suMap s)
instance PP (WithNames Subst) where instance PP (WithNames Subst) where
ppPrec _ (WithNames s mp) ppPrec _ (WithNames s mp)
| null els = text "(empty substitution)" | null els = text "(empty substitution)"
@ -123,7 +130,7 @@ apSubstMaybe su ty =
(TCLenFromThenTo,[t1,t2,t3]) -> Simp.tLenFromThenTo t1 t2 t3 (TCLenFromThenTo,[t1,t2,t3]) -> Simp.tLenFromThenTo t1 t2 t3
_ -> panic "apSubstMaybe" ["Unexpected type function", show t] _ -> panic "apSubstMaybe" ["Unexpected type function", show t]
PC _ -> Just $! Simp.simplify Map.empty (TCon t ss) PC _ ->Just $! Simp.simplify Map.empty (TCon t ss)
_ -> return (TCon t ss) _ -> return (TCon t ss)

View File

@ -486,12 +486,13 @@ tf2 f x y = TCon (TF f) [x,y]
tf3 :: TFun -> Type -> Type -> Type -> Type tf3 :: TFun -> Type -> Type -> Type -> Type
tf3 f x y z = TCon (TF f) [x,y,z] tf3 f x y z = TCon (TF f) [x,y,z]
{-
tAdd :: Type -> Type -> Type tAdd :: Type -> Type -> Type
tAdd x y tAdd x y
| Just x' <- tIsNum x | Just x' <- tIsNum x
, Just y' <- tIsNum y = error (show x' ++ " + " ++ show y') , Just y' <- tIsNum y = error (show x' ++ " + " ++ show y')
| otherwise = tf2 TCAdd x y | otherwise = tf2 TCAdd x y
-}
tSub :: Type -> Type -> Type tSub :: Type -> Type -> Type
tSub = tf2 TCSub tSub = tf2 TCSub
@ -511,9 +512,6 @@ tExp = tf2 TCExp
tMin :: Type -> Type -> Type tMin :: Type -> Type -> Type
tMin = tf2 TCMin tMin = tf2 TCMin
tMax :: Type -> Type -> Type
tMax = tf2 TCMax
tWidth :: Type -> Type tWidth :: Type -> Type
tWidth = tf1 TCWidth tWidth = tf1 TCWidth
@ -682,6 +680,7 @@ instance PP (WithNames Type) where
TRec fs -> braces $ fsep $ punctuate comma TRec fs -> braces $ fsep $ punctuate comma
[ pp l <+> text ":" <+> go 0 t | (l,t) <- fs ] [ pp l <+> text ":" <+> go 0 t | (l,t) <- fs ]
TUser c ts _ -> optParens (prec > 3) $ pp c <+> fsep (map (go 4) ts) TUser c ts _ -> optParens (prec > 3) $ pp c <+> fsep (map (go 4) ts)
-- TUser _ _ t -> ppPrec prec t -- optParens (prec > 3) $ pp c <+> fsep (map (go 4) ts)
TCon (TC tc) ts -> TCon (TC tc) ts ->
case (tc,ts) of case (tc,ts) of

View File

@ -0,0 +1,171 @@
module Cryptol.TypeCheck.TypePat
( aInf, aNat, aNat'
, anAdd, (|-|), aMul, (|^|), (|/|), (|%|)
, aMin, aMax
, aWidth
, aLenFromThen, aLenFromThenTo
, aTVar
, aBit
, aSeq
, aWord
, aChar
, aTuple
, (|->|)
, aFin, (|=|), (|/=|), (|>=|)
, aCmp, aArith
, aAnd
, aTrue
, anError
, module Cryptol.Utils.Patterns
) where
import Control.Applicative((<|>))
import Control.Monad
import Cryptol.Utils.Patterns
import Cryptol.TypeCheck.Type
import Cryptol.TypeCheck.Solver.InfNat
tcon :: TCon -> ([Type] -> a) -> Pat Type a
tcon f p = \ty -> case tNoUser ty of
TCon c ts | f == c -> return (p ts)
_ -> mzero
ar0 :: [a] -> ()
ar0 ~[] = ()
ar1 :: [a] -> a
ar1 ~[a] = a
ar2 :: [a] -> (a,a)
ar2 ~[a,b] = (a,b)
ar3 :: [a] -> (a,a,a)
ar3 ~[a,b,c] = (a,b,c)
tf :: TFun -> ([Type] -> a) -> Pat Type a
tf f ar = tcon (TF f) ar
tc :: TC -> ([Type] -> a) -> Pat Type a
tc f ar = tcon (TC f) ar
tp :: PC -> ([Type] -> a) -> Pat Prop a
tp f ar = tcon (PC f) ar
--------------------------------------------------------------------------------
aInf :: Pat Type ()
aInf = tc TCInf ar0
aNat :: Pat Type Integer
aNat = \a -> case tNoUser a of
TCon (TC (TCNum n)) _ -> return n
_ -> mzero
aNat' :: Pat Type Nat'
aNat' = \a -> (Inf <$ aInf a)
<|> (Nat <$> aNat a)
anAdd :: Pat Type (Type,Type)
anAdd = tf TCAdd ar2
(|-|) :: Pat Type (Type,Type)
(|-|) = tf TCSub ar2
aMul :: Pat Type (Type,Type)
aMul = tf TCMul ar2
(|^|) :: Pat Type (Type,Type)
(|^|) = tf TCExp ar2
(|/|) :: Pat Type (Type,Type)
(|/|) = tf TCDiv ar2
(|%|) :: Pat Type (Type,Type)
(|%|) = tf TCMod ar2
aMin :: Pat Type (Type,Type)
aMin = tf TCMin ar2
aMax :: Pat Type (Type,Type)
aMax = tf TCMax ar2
aWidth :: Pat Type Type
aWidth = tf TCWidth ar1
aLenFromThen :: Pat Type (Type,Type,Type)
aLenFromThen = tf TCLenFromThen ar3
aLenFromThenTo :: Pat Type (Type,Type,Type)
aLenFromThenTo = tf TCLenFromThenTo ar3
--------------------------------------------------------------------------------
aTVar :: Pat Type TVar
aTVar = \a -> case tNoUser a of
TVar x -> return x
_ -> mzero
aBit :: Pat Type ()
aBit = tc TCBit ar0
aSeq :: Pat Type (Type,Type)
aSeq = tc TCSeq ar2
aWord :: Pat Type Type
aWord = \a -> do (l,t) <- aSeq a
aBit t
return l
aChar :: Pat Type ()
aChar = \a -> do (l,t) <- aSeq a
n <- aNat l
guard (n == 8)
aBit t
aTuple :: Pat Type [Type]
aTuple = \a -> case tNoUser a of
TCon (TC (TCTuple _)) ts -> return ts
_ -> mzero
(|->|) :: Pat Type (Type,Type)
(|->|) = tc TCFun ar2
--------------------------------------------------------------------------------
aFin :: Pat Prop Type
aFin = tp PFin ar1
(|=|) :: Pat Prop (Type,Type)
(|=|) = tp PEqual ar2
(|/=|) :: Pat Prop (Type,Type)
(|/=|) = tp PNeq ar2
(|>=|) :: Pat Prop (Type,Type)
(|>=|) = tp PGeq ar2
aCmp :: Pat Prop Type
aCmp = tp PCmp ar1
aArith :: Pat Prop Type
aArith = tp PArith ar1
aAnd :: Pat Prop (Prop,Prop)
aAnd = tp PAnd ar2
aTrue :: Pat Prop ()
aTrue = tp PTrue ar0
--------------------------------------------------------------------------------
anError :: Pat Type TCErrorMessage
anError = \a -> case tNoUser a of
TCon (TError _ err) _ -> return err
_ -> mzero

View File

@ -10,7 +10,7 @@
{-# LANGUAGE DeriveDataTypeable, RecordWildCards #-} {-# LANGUAGE DeriveDataTypeable, RecordWildCards #-}
module Cryptol.Utils.Panic (panic) where module Cryptol.Utils.Panic (panic) where
import Cryptol.Version -- import Cryptol.Version
import Control.Exception as X import Control.Exception as X
import Data.Typeable(Typeable) import Data.Typeable(Typeable)
@ -29,7 +29,7 @@ instance Show CryptolPanic where
, "*** Please create an issue at https://github.com/galoisinc/cryptol/issues" , "*** Please create an issue at https://github.com/galoisinc/cryptol/issues"
, "" , ""
, "%< --------------------------------------------------- " , "%< --------------------------------------------------- "
] ++ rev ++ ]{- ++ rev ++
[ locLab ++ panicLoc p [ locLab ++ panicLoc p
, msgLab ++ fromMaybe "" (listToMaybe msgLines) , msgLab ++ fromMaybe "" (listToMaybe msgLines)
] ]
@ -47,7 +47,7 @@ instance Show CryptolPanic where
rev | null commitHash = [] rev | null commitHash = []
| otherwise = [ revLab ++ commitHash | otherwise = [ revLab ++ commitHash
, branchLab ++ commitBranch ++ dirtyLab ] , branchLab ++ commitBranch ++ dirtyLab ] -}
instance Exception CryptolPanic instance Exception CryptolPanic

View File

@ -0,0 +1,136 @@
{-# Language Safe, RankNTypes, MultiParamTypeClasses #-}
{-# Language FunctionalDependencies #-}
{-# Language FlexibleInstances #-}
{-# Language TypeFamilies, UndecidableInstances #-}
module Cryptol.Utils.Patterns where
import Control.Monad(liftM,liftM2,ap,MonadPlus(..),guard)
import Control.Applicative(Alternative(..))
newtype Match b = Match (forall r. r -> (b -> r) -> r)
instance Functor Match where
fmap = liftM
instance Applicative Match where
pure a = Match $ \_no yes -> yes a
(<*>) = ap
instance Monad Match where
fail _ = empty
Match m >>= f = Match $ \no yes -> m no $ \a ->
let Match n = f a in
n no yes
instance Alternative Match where
empty = Match $ \no _ -> no
Match m <|> Match n = Match $ \no yes -> m (n no yes) yes
instance MonadPlus Match where
type Pat a b = a -> Match b
(|||) :: Pat a b -> Pat a b -> Pat a b
p ||| q = \a -> p a <|> q a
-- | Check that a value satisfies multiple patterns.
-- For example, an "as" pattern is @(__ &&& p)@.
(&&&) :: Pat a b -> Pat a c -> Pat a (b,c)
p &&& q = \a -> liftM2 (,) (p a) (q a)
-- | Match a value, and modify the result.
(~>) :: Pat a b -> (b -> c) -> Pat a c
p ~> f = \a -> f <$> p a
-- | Match a value, and return the given result
(~~>) :: Pat a b -> c -> Pat a c
p ~~> f = \a -> f <$ p a
-- | View pattern.
(<~) :: (a -> b) -> Pat b c -> Pat a c
f <~ p = \a -> p (f a)
-- | Variable pattern.
__ :: Pat a a
__ = return
-- | Constant pattern.
succeed :: a -> Pat x a
succeed = const . return
-- | Predicate pattern
checkThat :: (a -> Bool) -> Pat a ()
checkThat p = \a -> guard (p a)
-- | Check for exact value.
lit :: Eq a => a -> Pat a ()
lit x = checkThat (x ==)
{-# Inline lit #-}
-- | Match a pattern, using the given default if valure.
matchDefault :: a -> Match a -> a
matchDefault a (Match m) = m a id
{-# Inline matchDefault #-}
-- | Match an irrefutable pattern. Crashes on faliure.
match :: Match a -> a
match m = matchDefault (error "Pattern match failure.") m
{-# Inline match #-}
matchMaybe :: Match a -> Maybe a
matchMaybe (Match m) = m Nothing Just
list :: [Pat a b] -> Pat [a] [b]
list [] = \a ->
case a of
[] -> return []
_ -> mzero
list (p : ps) = \as ->
case as of
[] -> mzero
x : xs ->
do a <- p x
bs <- list ps xs
return (a : bs)
(><) :: Pat a b -> Pat x y -> Pat (a,x) (b,y)
p >< q = \(a,x) -> do b <- p a
y <- q x
return (b,y)
class Matches thing pats res | pats -> thing res where
matches :: thing -> pats -> Match res
instance ( f ~ Pat a a1'
, a1 ~ Pat a1' r1
) => Matches a (f,a1) r1 where
matches ty (f,a1) = do a1' <- f ty
a1 a1'
instance ( op ~ Pat a (a1',a2')
, a1 ~ Pat a1' r1
, a2 ~ Pat a2' r2
) => Matches a (op,a1,a2) (r1,r2)
where
matches ty (f,a1,a2) = do (a1',a2') <- f ty
r1 <- a1 a1'
r2 <- a2 a2'
return (r1,r2)
instance ( op ~ Pat a (a1',a2',a3')
, a1 ~ Pat a1' r1
, a2 ~ Pat a2' r2
, a3 ~ Pat a3' r3
) => Matches a (op,a1,a2,a3) (r1,r2,r3) where
matches ty (f,a1,a2,a3) = do (a1',a2',a3') <- f ty
r1 <- a1 a1'
r2 <- a2 a2'
r3 <- a3 a3'
return (r1,r2,r3)