mixed checking order for better unification

on the `infer` function, we usually have something like this:

    infer (App f x) ctx =
      infer <- infer f ctx
      case infer of
        All a b -> do
          check x a ctx
          return $ out arg
        otherwise ->
          fail

this causes `x : a` to be happen before `infer (f x)` returns `b x`.
this is generally fine, but, in situations such as dependent
eliminations:

    λx (bool-elim ?A x t f) : (P x)

we really want `(elim ...)` to return `P x` BEFORE we check `t : ?A
true` and `t : ?A false`. that would allow the unification problem `?A x
== P x` to generate the solution `?A = λx (P x)` **before** the `t : ?A
true` check possibly fails.

being able to fill that metavar is very important for Kind2, since that
would allow us to omit motives in pattern-matches. because of that, I
think that the more sensible order is for infer to return its result
first, and then its inner checks occur. this is via a very lightweight
mechanism that consists of a list of suspended checks (`susp`), which we
push to inside `infer`, and fully consume inside `check`.

this is a middle-ground between checking in order (from left-to-right)
and a more complex suspension mechanism (involving dependency graphs).
with this simple solution, we're able to use metavars inside the motive
of dependent eliminations, greatly reducing the need for annotations in
practical code.
This commit is contained in:
Victor Taelin 2024-03-06 22:29:33 -03:00
parent edbaf3048d
commit f9a5bdb963
3 changed files with 142 additions and 150 deletions

View File

@ -1,8 +1,3 @@
Bool.lemma.notnot
: ∀(b: Bool) (Equal Bool (Bool.not (Bool.not b)) b)
= λb
(~b
λx (Equal Bool (Bool.not (Bool.not x)) x)
(Equal.refl Bool Bool.true)
(Equal.refl Bool Bool.false)
)
= λb (~b _ ?a (Equal.refl Bool Bool.false))

View File

@ -1,3 +1,8 @@
_main
: ∀(a: Bool) ∀(b: Bool) Bool
= λa λb (~a _ (~b _ ?A ?B) (~b _ ?C ?D))
: ∀(a: Bool) (List Bool)
= λa
let T = Bool.true
let F = Bool.false
let C = List.cons
let N = List.nil
(C _ T (C _ ?a (N _)))

View File

@ -10,6 +10,7 @@
import Data.Char (chr, ord)
import Prelude hiding (LT, GT, EQ)
import Debug.Trace
import Control.Monad (forM_)
-- Kind2 Types
-- -----------
@ -46,7 +47,8 @@ data Info
| Error Int Term Term Term Int
| Vague String
data State = State (Map Term) [Info] -- state type
data Check = Check Int Term Term Int
data State = State (Map Term) [Check] [Info] -- state type
data Res a = Done State a | Fail State -- result type
data Env a = Env (State -> Res a) -- environment computation
@ -180,10 +182,10 @@ envFail :: Env a
envFail = Env $ \state -> Fail state
envRun :: Env a -> Res a
envRun (Env chk) = chk (State mapNew [])
envRun (Env chk) = chk (State mapNew [] [])
envLog :: Info -> Env Int
envLog log = Env $ \(State fill logs) -> Done (State fill (log : logs)) 1
envLog log = Env $ \(State fill susp logs) -> Done (State fill susp (log : logs)) 1
envSnapshot :: Env State
envSnapshot = Env $ \state -> Done state state
@ -191,17 +193,17 @@ envSnapshot = Env $ \state -> Done state state
envRewind :: State -> Env Int
envRewind state = Env $ \_ -> Done state 0
envSusp :: Check -> Env ()
envSusp chk = Env $ \(State fill susp logs) -> Done (State fill (chk : susp) logs) ()
envFill :: Int -> Term -> Env ()
envFill k v = Env $ \(State fill logs) -> Done (State (mapSet (key k) v fill) logs) ()
envFill k v = Env $ \(State fill susp logs) -> Done (State (mapSet (key k) v fill) susp logs) ()
envGetFill :: Env (Map Term)
envGetFill = Env $ \(State fill logs) -> Done (State fill logs) fill
envGetFill = Env $ \(State fill susp logs) -> Done (State fill susp logs) fill
envTakeLogs :: Env [Info]
envTakeLogs = Env $ \(State fill logs) -> Done (State fill []) logs
envSetLogs :: [Info] -> Env ()
envSetLogs logs = Env $ \(State fill _) -> Done (State fill logs) ()
envTakeSusp :: Env [Check]
envTakeSusp = Env $ \(State fill susp logs) -> Done (State fill [] logs) susp
instance Functor Env where
fmap f (Env chk) = Env $ \logs -> case chk logs of
@ -324,13 +326,14 @@ termNormal fill lv term dep = termNormalGo fill lv (termReduce fill lv term) dep
-- Equality
-- --------
-- trace ("equal:\n- " ++ termShow a dep ++ "\n- " ++ termShow b dep) $ do
termEqual :: Term -> Term -> Int -> Env Bool
termEqual a b dep = do
fill <- envGetFill
let a' = termReduce fill 2 a
let b' = termReduce fill 2 b
termTryIdentical a' b' dep $ do
termSimilar a' b' dep
fill <- envGetFill
let a' = termReduce fill 2 a
let b' = termReduce fill 2 b
termTryIdentical a' b' dep $ do
termSimilar a' b' dep
termTryIdentical :: Term -> Term -> Int -> Env Bool -> Env Bool
termTryIdentical a b dep cont = do
@ -403,14 +406,14 @@ termIdenticalGo a (Ann bVal bTyp) dep =
-- envPure (aUid == bUid)
termIdenticalGo (Met aUid aSpn) b dep =
-- traceShow ("unify: " ++ show aUid ++ " x=" ++ termShow (Met aUid aSpn) dep ++ " t=" ++ termShow b dep) $
envBind (termUnify aUid aSpn b dep) $ \_ ->
envPure True
termUnify aUid aSpn b dep
termIdenticalGo a (Met bUid bSpn) dep =
-- traceShow ("unify: " ++ show bUid ++ " x=" ++ termShow (Met bUid bSpn) dep ++ " t=" ++ termShow a dep) $
envBind (termUnify bUid bSpn a dep) $ \_ ->
termUnify bUid bSpn a dep
termIdenticalGo (Hol aNam aCtx) b dep =
envPure True
termIdenticalGo a (Hol bNam bCtx) dep =
envPure True
termIdenticalGo (Hol aNam aCtx) (Hol bNam bCtx) dep =
envPure (aNam == bNam)
termIdenticalGo U60 U60 dep =
envPure True
termIdenticalGo (Num aVal) (Num bVal) dep =
@ -457,17 +460,18 @@ termIdenticalGo a b dep =
-- condition 3, and just allow recursive solutions.
-- If possible, solves a `(?X x y z ...) = K` problem, generating a subst.
termUnify :: Int -> [Term] -> Term -> Int -> Env ()
termUnify :: Int -> [Term] -> Term -> Int -> Env Bool
termUnify uid spn b dep = do
fill <- envGetFill
let unsolved = not (mapHas (key uid) fill)
let solvable = termUnifyValid fill spn []
if unsolved && solvable then do
let solution = termUnifySolve fill uid spn b
envLog (Solve uid solution dep)
-- trace ("solve: " ++ show uid ++ " " ++ termShow solution dep) $ do
envFill uid solution
return True
else
return ()
return False
-- Checks if an problem is solveable by pattern unification.
termUnifyValid :: Map Term -> [Term] -> [Int] -> Bool
@ -485,7 +489,9 @@ termUnifySolve fill uid (x : spn) b = case termReduce fill 0 x of
-- Substitutes a Bruijn level variable by a `neo` value in `term`.
termUnifySubst :: Int -> Term -> Term -> Term
-- termUnifySubst lvl neo term = term
termUnifySubst lvl neo (All nam inp bod) = All nam (termUnifySubst lvl neo inp) (\x -> termUnifySubst lvl neo (bod x))
termUnifySubst lvl neo (Lam nam bod) = Lam nam (\x -> termUnifySubst lvl neo (bod x))
termUnifySubst lvl neo (App fun arg) = App (termUnifySubst lvl neo fun) (termUnifySubst lvl neo arg)
termUnifySubst lvl neo (Ann val typ) = Ann (termUnifySubst lvl neo val) (termUnifySubst lvl neo typ)
termUnifySubst lvl neo (Slf nam typ bod) = Slf nam (termUnifySubst lvl neo typ) (\x -> termUnifySubst lvl neo (bod x))
@ -520,56 +526,54 @@ termInfer term dep =
termInferGo term dep
termInferGo :: Term -> Int -> Env Term
termInferGo (All nam inp bod) dep =
envBind (termCheck 0 inp Set dep) $ \_ ->
envBind (termCheck 0 (bod (Ann (Var nam dep) inp)) Set (dep + 1)) $ \_ ->
envPure Set
termInferGo (App fun arg) dep =
envBind (termInfer fun dep) $ \fun_typ ->
envBind envGetFill $ \fill ->
(termIfAll (termReduce fill 2 fun_typ)
(\fun_nam fun_typ_inp fun_typ_bod fun arg ->
envBind (termCheck 0 arg fun_typ_inp dep) $ \_ ->
envPure (fun_typ_bod arg))
(\fun arg -> do
envLog (Error 0 fun_typ (Hol "function" []) (App fun arg) dep)
envFail)
fun arg)
termInferGo (Ann val typ) dep =
envPure typ
termInferGo (Slf nam typ bod) dep =
envBind (termCheck 0 (bod (Ann (Var nam dep) typ)) Set (dep + 1)) $ \_ ->
envPure Set
termInferGo (Ins val) dep =
envBind (termInfer val dep) $ \vty ->
envBind envGetFill $ \fill ->
(termIfSlf (termReduce fill 2 vty)
(\vty_nam vty_typ vty_bod val ->
envPure (vty_bod (Ins val)))
(\val -> do
envLog (Error 0 vty (Hol "self-type" []) (Ins val) dep)
envFail)
val)
termInferGo (Ref nam val) dep =
termInferGo (All nam inp bod) dep = do
envSusp (Check 0 inp Set dep)
envSusp (Check 0 (bod (Ann (Var nam dep) inp)) Set (dep + 1))
return Set
termInferGo (App fun arg) dep = do
ftyp <- termInfer fun dep
fill <- envGetFill
case termReduce fill 2 ftyp of
(All ftyp_nam ftyp_inp ftyp_bod) -> do
envSusp (Check 0 arg ftyp_inp dep)
return $ ftyp_bod arg
otherwise -> do
envLog (Error 0 ftyp (Hol "function" []) (App fun arg) dep)
envFail
termInferGo (Ann val typ) dep = do
return typ
termInferGo (Slf nam typ bod) dep = do
envSusp (Check 0 (bod (Ann (Var nam dep) typ)) Set (dep + 1))
return Set
termInferGo (Ins val) dep = do
vtyp <- termInfer val dep
fill <- envGetFill
case termReduce fill 2 vtyp of
(Slf vtyp_nam vtyp_typ vtyp_bod) -> do
return $ vtyp_bod (Ins val)
otherwise -> do
envLog (Error 0 vtyp (Hol "self-type" []) (Ins val) dep)
envFail
termInferGo (Ref nam val) dep = do
termInfer val dep
termInferGo Set dep =
envPure Set
termInferGo U60 dep =
envPure Set
termInferGo (Num num) dep =
envPure U60
termInferGo (Txt txt) dep =
envPure xString
termInferGo (Op2 opr fst snd) dep =
envBind (termCheck 0 fst U60 dep) $ \_ ->
envBind (termCheck 0 snd U60 dep) $ \_ ->
envPure U60
termInferGo (Mat nam x z s p) dep =
envBind (termCheck 0 x U60 dep) $ \_ ->
envBind (termCheck 0 (p (Ann (Var nam dep) U60)) Set dep) $ \_ ->
envBind (termCheck 0 z (p (Num 0)) dep) $ \_ ->
envBind (termCheck 0 (s (Ann (Var (stringConcat nam "-1") dep) U60)) (p (Op2 ADD (Num 1) (Var (stringConcat nam "-1") dep))) (dep + 1)) $ \_ ->
envPure (p x)
termInferGo Set dep = do
return Set
termInferGo U60 dep = do
return Set
termInferGo (Num num) dep = do
return U60
termInferGo (Txt txt) dep = do
return xString
termInferGo (Op2 opr fst snd) dep = do
envSusp (Check 0 fst U60 dep)
envSusp (Check 0 snd U60 dep)
return U60
termInferGo (Mat nam x z s p) dep = do
envSusp (Check 0 x U60 dep)
envSusp (Check 0 (p (Ann (Var nam dep) U60)) Set dep)
envSusp (Check 0 z (p (Num 0)) dep)
envSusp (Check 0 (s (Ann (Var (stringConcat nam "-1") dep) U60)) (p (Op2 ADD (Num 1) (Var (stringConcat nam "-1") dep))) (dep + 1))
return (p x)
termInferGo (Lam nam bod) dep = do
envLog (Error 0 (Hol "untyped_term" []) (Hol "type_annotation" []) (Lam nam bod) dep)
envFail
@ -585,7 +589,7 @@ termInferGo (Met uid spn) dep = do
termInferGo (Var nam idx) dep = do
envLog (Error 0 (Hol "untyped_term" []) (Hol "type_annotation" []) (Var nam idx) dep)
envFail
termInferGo (Src src val) dep =
termInferGo (Src src val) dep = do
termInfer val dep
termCheck :: Int -> Term -> Term -> Int -> Env ()
@ -594,54 +598,56 @@ termCheck src val typ dep =
termCheckGo src val typ dep
termCheckGo :: Int -> Term -> Term -> Int -> Env ()
termCheckGo src (Lam termNam termBod) typx dep =
envBind envGetFill $ \fill ->
(termIfAll (termReduce fill 2 typx)
(\typeNam typeInp typeBod termBod ->
termCheckGo src (Lam termNam termBod) typx dep = do
fill <- envGetFill
case termReduce fill 2 typx of
(All typeNam typeInp typeBod) -> do
let ann = Ann (Var termNam dep) typeInp
term = termBod ann
typx = typeBod ann
in termCheck 0 term typx (dep + 1))
(\termBod ->
envBind (termInfer (Lam termNam termBod) dep) $ \x ->
envPure ())
termBod)
termCheckGo src (Ins termVal) typx dep =
envBind envGetFill $ \fill ->
(termIfSlf (termReduce fill 2 typx)
(\typeNam typeTyp typeBod termVal ->
termCheck 0 termVal (typeBod (Ins termVal)) dep)
(\termVal ->
envBind (termInfer (Ins termVal) dep) $ \x ->
envPure ())
termVal)
termCheckGo src (Let termNam termVal termBod) typx dep =
termCheck 0 (termBod termVal) typx (dep + 1)
termCheckGo src (Hol termNam termCtx) typx dep =
envBind (envLog (Found termNam typx termCtx dep)) $ \x ->
envPure ()
termCheckGo src (Met uid spn) typx dep =
envPure ()
termCheckGo src (Ref termNam (Ann termVal termTyp)) typx dep =
envBind (termEqual typx termTyp dep) $ \equal ->
termCheckReport src equal termTyp typx termVal dep
termCheckGo src (Src termSrc termVal) typx dep =
let term = termBod ann
let typx = typeBod ann
termCheck 0 term typx (dep + 1)
otherwise -> do
termInfer (Lam termNam termBod) dep
return ()
termCheckGo src (Ins termVal) typx dep = do
fill <- envGetFill
case termReduce fill 2 typx of
Slf typeNam typeTyp typeBod -> do
termCheck 0 termVal (typeBod (Ins termVal)) dep
_ -> do
termInfer (Ins termVal) dep
return ()
termCheckGo src (Let termNam termVal termBod) typx dep = do
termTyp <- termInfer termVal dep
termCheck 0 (termBod (Ann (Var termNam dep) termTyp)) typx dep
termCheckGo src (Hol termNam termCtx) typx dep = do
envLog (Found termNam typx termCtx dep)
return ()
termCheckGo src (Met uid spn) typx dep = do
return ()
-- termCheckGo src (Ref termNam (Ann termVal termTyp)) typx dep = do
-- equal <- termEqual typx termTyp dep
-- termCheckReport src equal termTyp typx termVal dep
termCheckGo src (Src termSrc termVal) typx dep = do
termCheck termSrc termVal typx dep
termCheckGo src term typx dep =
termCheckVerify src term typx dep
termCheckVerify :: Int -> Term -> Term -> Int -> Env ()
termCheckVerify src term typx dep =
envBind (termInfer term dep) $ \infer ->
envBind (termEqual typx infer dep) $ \equal ->
termCheckReport src equal infer typx term dep
termCheckReport :: Int -> Bool -> Term -> Term -> Term -> Int -> Env ()
termCheckReport src False detected expected value dep = do
envLog (Error src detected expected value dep)
envFail
termCheckReport src True detected expected value dep =
envPure ()
termCheckGo src term typx dep = do
infer <- termInfer term dep
equal <- termEqual typx infer dep
if equal then do
susp <- envTakeSusp
forM_ susp $ \(Check src val typ dep) -> do
termCheckGo src val typ dep
return ()
else do
envLog (Error src infer typx term dep)
envFail
-- termCheckReport :: Int -> Bool -> Term -> Term -> Term -> Int -> Env ()
-- termCheckReport src False detected expected value dep = do
-- envLog (Error src detected expected value dep)
-- envFail
-- termCheckReport src True detected expected value dep =
-- envPure ()
termCheckDef :: Term -> Env ()
termCheckDef (Ref nam (Ann val typ)) = termCheck 0 val typ 0 >> return ()
@ -701,8 +707,9 @@ termShow (Txt txt) dep = stringJoin [quote , txt , quote]
termShow (Hol nam ctx) dep = stringJoin ["?" , nam]
termShow (Met uid spn) dep = stringJoin ["(_", termShowSpn (reverse spn) dep, ")"]
termShow (Var nam idx) dep = nam
-- termShow (Src src val) dep = stringJoin ["!", (termShow val dep)]
-- termShow (Var nam idx) dep = stringJoin [nam, "^", show idx]
termShow (Src src val) dep = termShow val dep
-- termShow (Src src val) dep = stringJoin ["!", (termShow val dep)]
termShowSpn :: [Term] -> Int -> String
termShowSpn [] dep = ""
@ -739,9 +746,9 @@ infoShow fill (Found name typ ctx dep) =
let typ' = termShow (termNormal fill 1 typ dep) dep
ctx' = stringTail (contextShow fill ctx dep)
in stringJoin ["#found{", name, " ", typ', " [", ctx', "]}"]
infoShow fill (Error src detected expected value dep) =
let det = termShow (termNormal fill 1 detected dep) dep
exp = termShow (termNormal fill 1 expected dep) dep
infoShow fill (Error src expected detected value dep) =
let det = termShow detected dep
exp = termShow expected dep
val = termShow (termNormal fill 0 value dep) dep
in stringJoin ["#error{", exp, " ", det, " ", val, " ", u60Show src, "}"]
infoShow fill (Solve name term dep) =
@ -754,30 +761,15 @@ infoShow fill (Vague name) =
-- ---
apiCheck :: Term -> IO ()
apiCheck term = case envRun $ apiCheckGo 0 term of
apiCheck term = case envRun $ termCheckDef term of
Done state value -> apiPrintLogs state
Fail state -> apiPrintLogs state
apiCheckGo :: Int -> Term -> Env ()
apiCheckGo 9 term = return ()
apiCheckGo n term = do
termCheckDef term
logs <- envTakeLogs
if any infoIsSolve logs then
apiCheckGo (n + 1) term
else
envSetLogs logs
-- TODO: IMPLEMENT FUNCTION THAT SHOWS LIST OF INFOS
infosShow :: [Info] -> String
infosShow [] = ""
infosShow (info:infos) = infoShow mapNew info ++ "; " ++ infosShow infos
apiPrintLogs :: State -> IO ()
apiPrintLogs (State fill (log : logs)) = do
apiPrintLogs (State fill susp (log : logs)) = do
putStrLn $ infoShow fill log
apiPrintLogs (State fill logs)
apiPrintLogs (State fill []) = do
apiPrintLogs (State fill susp logs)
apiPrintLogs (State fill susp []) = do
return ()
-- Main