diff --git a/src/Cryptol/TypeCheck/SimpType.hs b/src/Cryptol/TypeCheck/SimpType.hs index 0192788d..4365bbe8 100644 --- a/src/Cryptol/TypeCheck/SimpType.hs +++ b/src/Cryptol/TypeCheck/SimpType.hs @@ -53,6 +53,8 @@ tAdd x y guard (b == y) return a) = v + | Just v <- matchMaybe (factor <|> same <|> swapVars) = v + | otherwise = tf2 TCAdd x y where isSumK t = case tNoUser t of @@ -63,8 +65,32 @@ tAdd x y addK 0 t = t addK n t | Just (m,b) <- isSumK t = tf2 TCAdd (tNum (n + m)) b + | Just v <- matchMaybe + $ do (a,b) <- (|-|) t + (do m <- aNat b + return $ case compare n m of + GT -> tAdd (tNum (n-m)) a + EQ -> a + LT -> tSub a (tNum (m-n))) + <|> + (do m <- aNat a + return (tSub (tNum (m+n)) b)) + = v | otherwise = tf2 TCAdd (tNum n) t + factor = do (a,b1) <- aMul x + (a',b2) <- aMul y + guard (a == a') + return (tMul a (tAdd b1 b2)) + + same = do guard (x == y) + return (tMul (tNum (2 :: Int)) x) + + swapVars = do a <- aTVar x + b <- aTVar y + guard (b < a) + return (tf2 TCAdd y x) + tSub :: Type -> Type -> Type tSub x y | Just t <- tOp TCSub (op2 nSub) [x,y] = t @@ -97,6 +123,7 @@ tMul x y | Just t <- tOp TCMul (total (op2 nMul)) [x,y] = t | Just n <- tIsNum x = mulK n y | Just n <- tIsNum y = mulK n x + | Just v <- matchMaybe swapVars = v | otherwise = tf2 TCMul x y where mulK 0 _ = tNum (0 :: Int) @@ -114,6 +141,13 @@ tMul x y | otherwise = tf2 TCMul (tNum n) t where t' = tNoUser t + swapVars = do a <- aTVar x + b <- aTVar y + guard (b < a) + return (tf2 TCMul y x) + + + tDiv :: Type -> Type -> Type tDiv x y | Just t <- tOp TCDiv (op2 nDiv) [x,y] = t @@ -143,10 +177,17 @@ tMin x y | Just t <- tOp TCMin (total (op2 nMin)) [x,y] = t | Just n <- tIsNat' x = minK n y | Just n <- tIsNat' y = minK n x + | Just n <- matchMaybe (minPlusK x y <|> minPlusK y x) = n | x == y = x -- XXX: min (k + t) t -> t | otherwise = tf2 TCMin x y where + minPlusK a b = do (l,r) <- anAdd a + k <- aNat l + guard (k >= 1 && b == r) + return b + + minK Inf t = t minK (Nat 0) _ = tNum (0 :: Int) minK (Nat k) t @@ -234,8 +275,10 @@ tOp tf f ts where err xs = TCErrorMessage $ "Invalid applicatoin of " ++ show (pp tf) ++ " to " ++ - unwords (map show xs) - + unwords (map ppIN xs) + + ppIN Inf = "inf" + ppIN (Nat x) = show x diff --git a/src/Cryptol/TypeCheck/Solver/Improve.hs b/src/Cryptol/TypeCheck/Solver/Improve.hs index 55b14863..22eb7579 100644 --- a/src/Cryptol/TypeCheck/Solver/Improve.hs +++ b/src/Cryptol/TypeCheck/Solver/Improve.hs @@ -54,6 +54,7 @@ improveEq impSkol fins prop = guard (v `Set.notMember` fvs other) return (singleSubst v (Mk.tSub other s), [ other >== s ]) + isSum t = do (v,s) <- matches t (anAdd, aTVar, __) valid v s <|> do (s,v) <- matches t (anAdd, __, aTVar) diff --git a/src/Cryptol/TypeCheck/Solver/Numeric.hs b/src/Cryptol/TypeCheck/Solver/Numeric.hs index 92f5d78f..afdd1576 100644 --- a/src/Cryptol/TypeCheck/Solver/Numeric.hs +++ b/src/Cryptol/TypeCheck/Solver/Numeric.hs @@ -17,14 +17,16 @@ import Cryptol.TypeCheck.SimpType cryIsEqual :: Ctxt -> Type -> Type -> Solved -cryIsEqual _ t1 t2 = +cryIsEqual ctxt t1 t2 = matchDefault Unsolved $ (pBin PEqual (==) t1 t2) - <|> (aNat' t1 >>= tryEqK t2) - <|> (aNat' t2 >>= tryEqK t1) + <|> (aNat' t1 >>= tryEqK ctxt t2) + <|> (aNat' t2 >>= tryEqK ctxt t1) <|> (aTVar t1 >>= tryEqVar t2) <|> (aTVar t2 >>= tryEqVar t1) <|> ( guard (t1 == t2) >> return (SolvedIf [])) + <|> tryEqMin t1 t2 + <|> tryEqMin t2 t1 @@ -36,6 +38,7 @@ cryIsGeq i t1 t2 = matchDefault Unsolved $ (pBin PGeq (>=) t1 t2) <|> (aNat' t1 >>= tryGeqKThan i t2) + <|> (aNat' t2 >>= tryGeqThanK i t1) <|> (aTVar t2 >>= tryGeqThanVar i t1) <|> tryGeqThanSub i t1 t2 <|> (geqByInterval i t1 t2) @@ -73,6 +76,17 @@ tryGeqKThan _ ty (Nat n) = Nat 0 -> [] Nat k -> [ tNum (div n k) >== b ] +tryGeqThanK :: Ctxt -> Type -> Nat' -> Match Solved +tryGeqThanK _ t Inf = return (SolvedIf [ t =#= tInf ]) +tryGeqThanK _ t (Nat k) = + do (a,b) <- anAdd t + n <- aNat a + return $ SolvedIf $ if n >= k + then [] + else [ b >== tNum (k - n) ] + + + tryGeqThanSub :: Ctxt -> Type -> Type -> Match Solved tryGeqThanSub _ x y = do (a,_) <- (|-|) y @@ -98,6 +112,13 @@ geqByInterval ctxt x y = -------------------------------------------------------------------------------- +tryEqMin :: Type -> Type -> Match Solved +tryEqMin x y = + do (a,b) <- aMin x + let check m1 m2 = do guard (m1 == y) + return $ SolvedIf [ m2 >== m1 ] + check a b <|> check b a + tryEqVar :: Type -> TVar -> Match Solved tryEqVar ty x = @@ -135,9 +156,14 @@ tryEqVar ty x = -- e.g., 10 = t -tryEqK :: Type -> Nat' -> Match Solved -tryEqK ty lk = - +tryEqK :: Ctxt -> Type -> Nat' -> Match Solved +tryEqK ctxt ty lk = + do guard (lk == Inf) + (a,b) <- anAdd ty + let check x y = do guard (iIsFin (typeInterval ctxt x)) + return $ SolvedIf [ y =#= tInf ] + check a b <|> check b a + <|> do (rk, b) <- matches ty (anAdd, aNat', __) return $ case nSub lk rk of diff --git a/tests/issues/issue101.icry.stdout b/tests/issues/issue101.icry.stdout index d09f30e3..d3dfb2e7 100644 --- a/tests/issues/issue101.icry.stdout +++ b/tests/issues/issue101.icry.stdout @@ -1,8 +1,4 @@ Loading module Cryptol [error] at :1:1--1:11: - Unsolvable constraint: - 0 >= 1 - arising from - use of partial type function - - at :1:7--1:10 + Invalid applicatoin of - to 0 1 diff --git a/tests/issues/issue226.icry.stdout b/tests/issues/issue226.icry.stdout index 846e0bc8..b45f7bd2 100644 --- a/tests/issues/issue226.icry.stdout +++ b/tests/issues/issue226.icry.stdout @@ -8,7 +8,7 @@ Type Synonyms type Char = [8] type String n = [n][8] type Word n = [n] - type lg2 n = width (max n 1 - 1) + type lg2 n = width (max 1 n - 1) Symbols ======= @@ -60,8 +60,8 @@ Symbols last == len) => [len][bits] fromTo : {first, last, bits} (fin last, fin bits, last >= first, bits >= width last) => [1 + (last - first)][bits] - groupBy : {each, parts, elem} (fin each) => [parts * - each]elem -> [parts][each]elem + groupBy : {each, parts, elem} (fin each) => [each * + parts]elem -> [parts][each]elem infFrom : {bits} (fin bits) => [bits] -> [inf][bits] infFromThen : {bits} (fin bits) => [bits] -> [bits] -> [inf][bits] join : {parts, each, a} (fin each) => [parts][each]a -> [parts * @@ -72,7 +72,8 @@ Symbols negate : {a} (Arith a) => a -> a pdiv : {a, b} (fin a, fin b) => [a] -> [b] -> [a] pmod : {a, b} (fin a, fin b) => [a] -> [1 + b] -> [b] - pmult : {a, b} (fin a, fin b) => [1 + a] -> [1 + b] -> [1 + a + b] + pmult : {a, b} (fin a, fin b) => [1 + a] -> [1 + b] -> [1 + + (a + b)] random : {a} [256] -> a reverse : {a, b} (fin a) => [a]b -> [a]b split : {parts, each, a} (fin each) => [parts * diff --git a/tests/mono-binds/test01.icry.stdout b/tests/mono-binds/test01.icry.stdout index fbbf2ad1..bcb9eaaf 100644 --- a/tests/mono-binds/test01.icry.stdout +++ b/tests/mono-binds/test01.icry.stdout @@ -5,15 +5,14 @@ module test01 import Cryptol /* Not recursive */ test01::a : {a, b} (fin a) => [a]b -> [2 * a]b -test01::a = \{a, b} (fin a) -> - (\ (x : [a]b) -> - test01::f a x - where - /* Not recursive */ - test01::f : {c} [c]b -> [a + c]b - test01::f = \{c} (y : [c]b) -> (Cryptol::#) a c b <> x y - - ) : [a]b -> [2 * a]b +test01::a = \{a, b} (fin a) (x : [a]b) -> + test01::f a x + where + /* Not recursive */ + test01::f : {c} [c]b -> [a + c]b + test01::f = \{c} (y : [c]b) -> (Cryptol::#) a c b <> x y + + Loading module Cryptol Loading module test01 @@ -21,13 +20,12 @@ module test01 import Cryptol /* Not recursive */ test01::a : {a, b} (fin a) => [a]b -> [2 * a]b -test01::a = \{a, b} (fin a) -> - (\ (x : [a]b) -> - test01::f x - where - /* Not recursive */ - test01::f : [a]b -> [a + a]b - test01::f = \ (y : [a]b) -> (Cryptol::#) a a b <> x y - - ) : [a]b -> [2 * a]b +test01::a = \{a, b} (fin a) (x : [a]b) -> + test01::f x + where + /* Not recursive */ + test01::f : [a]b -> [2 * a]b + test01::f = \ (y : [a]b) -> (Cryptol::#) a a b <> x y + + diff --git a/tests/mono-binds/test05.icry.stdout b/tests/mono-binds/test05.icry.stdout index 22fecca5..6d4adc8b 100644 --- a/tests/mono-binds/test05.icry.stdout +++ b/tests/mono-binds/test05.icry.stdout @@ -32,10 +32,9 @@ test05::test = \{a, b, c} (fin c, c >= 4) (a : [a]b) -> test05::foo = a /* Not recursive */ - test05::bar : {e} (fin e) => [a + e]b + test05::bar : {e} (fin e) => [e + a]b test05::bar = \{e} (fin e) -> - (Cryptol::#) e a b <> (Cryptol::zero ([e]b)) test05::foo : [a + - e]b + (Cryptol::#) e a b <> (Cryptol::zero ([e]b)) test05::foo @@ -70,7 +69,7 @@ test05::test = \{a, b, c} (fin c, c >= 4) (a : [a]b) -> test05::foo = Cryptol::demote 10 10 <> <> <> /* Not recursive */ - test05::f : [0 + a]b + test05::f : [a]b test05::f = test05::bar where /* Not recursive */ @@ -78,7 +77,7 @@ test05::test = \{a, b, c} (fin c, c >= 4) (a : [a]b) -> test05::foo = a /* Not recursive */ - test05::bar : [0 + a]b + test05::bar : [a]b test05::bar = (Cryptol::#) 0 a b <> (Cryptol::zero ([0]b)) test05::foo