-- Idris2-boot/src/TTImp/ProcessDef.idr
-- (Idris source; 203 lines, 8.1 KiB)

module TTImp.ProcessDef
import Core.Context
import Core.Core
import Core.Env
import Core.Normalise
import Core.Value
import Core.UnifyState
import TTImp.Elab
import TTImp.Elab.Check
import TTImp.TTImp
-- Turn a term into a pattern.
-- The first argument collects the argument patterns converted so far while
-- descending through applications; the second is the term the conversion
-- started from, which is what gets reported when nothing matches.
-- Clause order matters: anything not recognised by an earlier clause
-- (including a local or primitive applied to arguments) ends up as
-- PUnmatchable.
mkPat' : List (Pat vars) -> Term vars -> Term vars -> Pat vars
mkPat' [] _ (Local loc _ ix prf) = PLoc loc ix prf
mkPat' spine _ (Ref loc (DataCon tg ar) cn) = PCon loc cn tg ar spine
mkPat' spine _ (Ref loc (TyCon _ ar) tyn) = PTyCon loc tyn ar spine
mkPat' [] _ (Bind loc x (Pi _ _ dom) cod)
    = PArrow loc x (mkPat' [] dom dom) (mkPat' [] cod cod)
-- Convert the argument on its own, push it onto the spine, and carry on
-- with the function part
mkPat' spine whole (App loc fun _ argtm)
    = mkPat' (mkPat' [] argtm argtm :: spine) whole fun
mkPat' spine whole (As loc ix prf inner)
    = PAs loc ix prf (mkPat' spine whole inner)
mkPat' [] _ (PrimVal loc cst) = PConst loc cst
mkPat' _ whole _ = PUnmatchable (getLoc whole) whole
-- Convert a term to a pattern, starting with no collected arguments and
-- using the term itself as the one to report if it is unmatchable.
mkPat : Term vars -> Pat vars
mkPat t = mkPat' [] t t
-- Walk down a checked LHS and its type in step, moving each leading pattern
-- variable binder into the environment.
-- The result packages up: the extended environment (the one to check the RHS
-- in), what is left of the LHS and of its type once the binders have been
-- stripped, and a function which builds the final result from the argument
-- patterns and a checked RHS.
-- The continuation argument is wrapped at every binder so that the pattern
-- alternative is rebuilt with a CBind for each variable taken off.
extendEnv : Env Term vars ->
            Term vars -> Term vars ->
            (PatAlt vars -> a) ->
            Core (vars' ** (Env Term vars',
                            Term vars', Term vars',
                            List (Pat vars') -> Term vars' -> a))
extendEnv env (Bind _ x (PVar rig pty) body) (Bind _ y (PVTy _ _) tybody) k with (nameEq x y)
  -- The term binder and the type binder must be for the same name
  extendEnv env (Bind _ x (PVar rig pty) body) (Bind _ y (PVTy _ _) tybody) k | Nothing
      = throw (InternalError "Can't happen: names don't match in pattern type")
  extendEnv env (Bind _ x (PVar rig pty) body) (Bind _ x (PVTy _ _) tybody) k | (Just Refl)
      = extendEnv (PVar rig pty :: env) body tybody
                  (\inner => k (CBind rig x pty inner))
-- No more matching pattern binders: what remains is the LHS proper
extendEnv env tm ty k
    = pure (_ ** (env, tm, ty, \pats, rhs => k (CPats pats rhs)))
-- Collect the locally bound names which appear as arguments to a function,
-- each paired with the multiplicity of the argument position it appears in
-- (scaled by the multiplicity we are currently working at), so that we know
-- how to bind them on the right hand side of the pattern.
-- The Nat counts the binders gone under so far; only variables with an
-- index below it are recorded (we may be shadowing names if this is a local
-- definition, so earlier ones are left alone).
-- The Bool says whether we are at the top level application: arguments of a
-- Func applied anywhere else are collected at Rig0.
findLinear : {auto c : Ref Ctxt Defs} ->
             Bool -> Nat -> RigCount -> Term vars ->
             Core (List (Name, RigCount))
findLinear top bound rig (Bind fc n b sc)
    = findLinear top (S bound) rig sc
findLinear top bound rig tm
    = case getFnArgs tm of
           (Ref _ _ n, []) => pure []
           (Ref _ nt n, argsi)
               => do defs <- get Ctxt
                     Just nty <- lookupTyExact n (gamma defs)
                          | Nothing => pure []
                     tynf <- nf defs [] nty
                     findLinArg (accessible nt rig) tynf (map snd argsi)
           _ => pure []
  where
    -- Multiplicity at which to inspect the arguments of the head name
    accessible : NameType -> RigCount -> RigCount
    accessible Func r = if top then r else Rig0
    accessible _ r = r

    -- Step through the arguments alongside the normalised type of the head
    findLinArg : RigCount -> NF [] -> List (Term vars) ->
                 Core (List (Name, RigCount))
    -- A variable in a Pi-typed position: record it only if it is one of the
    -- locally bound ones. Either way, continue into the scope of the Pi.
    findLinArg rig (NBind _ x (Pi mult _ _) sc) (Local {name=v} fc _ idx prf :: more)
        = do sc' <- sc (toClosure defaultOpts [] (Ref fc Bound x))
             found <- findLinArg rig sc' more
             pure (if idx < bound
                      then (v, rigMult mult rig) :: found
                      else found)
    -- Any other argument in a Pi-typed position: look inside it at the
    -- scaled multiplicity
    findLinArg rig (NBind fc x (Pi mult _ _) sc) (a :: more)
        = do inArg <- findLinear False bound (rigMult mult rig) a
             sc' <- sc (toClosure defaultOpts [] (Ref fc Bound x))
             found <- findLinArg rig sc' more
             pure (inArg ++ found)
    -- The type is not a Pi binder: look inside the argument at the current
    -- multiplicity, and leave the type as it is
    findLinArg rig ty (a :: more)
        = do inArg <- findLinear False bound rig a
             found <- findLinArg rig ty more
             pure (inArg ++ found)
    findLinArg _ _ [] = pure []
-- Update the multiplicities on the leading PVar/PVTy binders of a term:
-- a binder whose name has an entry in the list takes the multiplicity from
-- that entry, any other binder keeps the one it already has. Stops at the
-- first thing which is not such a binder.
setLinear : List (Name, RigCount) -> Term vars -> Term vars
setLinear vs (Bind fc x (PVar old ty) sc)
    = let new = fromMaybe old (lookup x vs) in
          Bind fc x (PVar new ty) (setLinear vs sc)
setLinear vs (Bind fc x (PVTy old ty) sc)
    = let new = fromMaybe old (lookup x vs) in
          Bind fc x (PVTy new ty) (setLinear vs sc)
setLinear vs tm = tm
-- Merge repeated entries for the same name into a single entry per name.
-- How multiplicities combine on the LHS:
--   Rig1 with Rig1 or RigW is an error (LinearUsed), since it means the
--     name is used more than once
--   RigW with RigW is RigW
--   Rig0 with anything is the other multiplicity
combineLinear : FC -> List (Name, RigCount) ->
                Core (List (Name, RigCount))
combineLinear loc [] = pure []
combineLinear loc ((n, count) :: cs)
    = case lookupAll n cs of
           -- No other occurrence of this name, so keep the entry as it is
           [] => do rest <- combineLinear loc cs
                    pure ((n, count) :: rest)
           others => do merged <- combineAll count others
                        rest <- combineLinear loc (filter notN cs)
                        pure ((n, merged) :: rest)
  where
    -- True for entries which are for a name other than the current one
    notN : (Name, RigCount) -> Bool
    notN entry = n /= fst entry

    -- Every multiplicity recorded against the given name, in list order
    lookupAll : Name -> List (Name, RigCount) -> List RigCount
    lookupAll n xs = map snd (filter (\entry => n == fst entry) xs)

    combine : RigCount -> RigCount -> Core RigCount
    combine Rig1 Rig1 = throw (LinearUsed loc 2 n)
    combine Rig1 RigW = throw (LinearUsed loc 2 n)
    combine RigW Rig1 = throw (LinearUsed loc 2 n)
    combine RigW RigW = pure RigW
    combine Rig0 c = pure c
    combine c Rig0 = pure c

    -- Fold the multiplicities in from left to right, stopping at the first
    -- invalid combination
    combineAll : RigCount -> List RigCount -> Core RigCount
    combineAll acc [] = pure acc
    combineAll acc (next :: more) = combineAll !(combine acc next) more
-- Check one clause of a definition and return the pattern alternative it
-- gives rise to. The result type allows Nothing for an impossible clause,
-- but impossible clauses are not handled yet: they are rejected with an
-- internal error.
-- Steps for a pattern clause: elaborate the LHS, normalise it and its type,
-- work out the multiplicity of each pattern variable and write those into
-- the binders, move the binders into the environment, then check the RHS
-- against the remaining LHS type.
export
checkClause : {auto c : Ref Ctxt Defs} ->
              {auto u : Ref UST UState} ->
              (mult : RigCount) -> (hashit : Bool) ->
              Name -> Env Term vars ->
              ImpClause -> Core (Maybe (PatAlt vars))
checkClause mult hashit n env (ImpossibleClause fc lhs)
    = throw (InternalError "impossible not implemented yet")
checkClause mult hashit n env (PatClause fc lhs_in rhs)
    = do lhs <- lhsInCurrentNS lhs_in
         (lhsRaw, lhsTyGlued) <- elabTerm n (InLHS mult) env
                                     (IBindHere fc PATTERN lhs) Nothing
         lhsTyRaw <- getTerm lhsTyGlued
         -- Normalise the LHS so that any functions or let bindings in it are
         -- evaluated (this might be allowed, e.g. for 'fromInteger')
         defs <- get Ctxt
         lhstm <- normalise defs (noLet env) lhsRaw
         lhsty <- normaliseHoles defs env lhsTyRaw

         -- Work out how each pattern variable is used, and record that on
         -- the binders of both the LHS and its type
         usesIn <- findLinear True 0 Rig1 lhstm
         log 5 ("Linearity of names in " ++ show n ++ ": " ++ show usesIn)
         uses <- combineLinear fc usesIn
         let lhstmLin = setLinear uses lhstm
         let lhstyLin = setLinear uses lhsty

         logTermNF 0 "LHS term" env lhstmLin
         logTermNF 0 "LHS type" env lhstyLin

         (vars' ** (env', lhstm', lhsty', mkAlt)) <-
             extendEnv env lhstmLin lhstyLin id

         -- Fetch the definitions again, now that the LHS has been processed
         defs' <- get Ctxt
         rhstm <- checkTerm n InExpr env' rhs (gnf defs' env' lhsty')
         logTermNF 0 "RHS term" env' rhstm

         let pats = map (\arg => mkPat (snd arg)) (getArgs lhstm')
         pure (Just (mkAlt pats rhstm))
  where
    -- Replace every Let in the environment with a Lam of the same
    -- multiplicity and type, leaving all other binders alone
    noLet : Env Term vs -> Env Term vs
    noLet [] = []
    noLet (Let rig val ty :: rest) = Lam rig Explicit ty :: noLet rest
    noLet (b :: rest) = b :: noLet rest
-- Process the clauses given for a name which has already been declared.
-- Fails with NoDeclaration if the name is not in the context, and with
-- AlreadyDefined if it already has a definition other than None.
-- NOTE: this is unfinished. The clauses are checked, but the definition
-- ends in a hole, so nothing is done with the result yet.
export
processDef : {auto c : Ref Ctxt Defs} ->
             {auto u : Ref UST UState} ->
             Env Term vars -> FC ->
             Name -> List ImpClause -> Core ()
processDef env fc n_in cs_in
    = do n <- inCurrentNS n_in
         defs <- get Ctxt
         Just decl <- lookupCtxtExact n (gamma defs)
              | Nothing => throw (NoDeclaration fc n)
         let None = definition decl
              | _ => throw (AlreadyDefined fc n)
         let ty = type decl
         let hashit = visibility decl == Public
         -- Clauses are checked at Rig0 only when the name itself is
         -- declared at Rig0, and at Rig1 in every other case
         let mult = if multiplicity decl == Rig0 then Rig0 else Rig1
         alts <- traverse (checkClause mult hashit n env) cs_in
         ?something