mirror of
https://github.com/idris-lang/Idris2.git
synced 2024-12-21 18:51:40 +03:00
initial implementation of 3 heuristics.
This commit is contained in:
parent
c725b11c89
commit
7ca774d8e6
@ -11,6 +11,9 @@ import Core.Value
|
||||
|
||||
import Libraries.Data.LengthMatch
|
||||
import Data.List
|
||||
import Data.Vect
|
||||
|
||||
import Debug.Trace
|
||||
|
||||
import Decidable.Equality
|
||||
|
||||
@ -94,6 +97,11 @@ data NamedPats : List Name -> -- pattern variables still to process
|
||||
-- refers to are explicit
|
||||
NamedPats vars ns -> NamedPats vars (pvar :: ns)
|
||||
|
||||
length : NamedPats vars ps -> (n ** n = length ps)
|
||||
length [] = (0 ** Refl)
|
||||
length (_ :: xs) = let (n' ** prf) = length xs
|
||||
in ((S n') ** cong S prf)
|
||||
|
||||
getPatInfo : NamedPats vars todo -> List Pat
|
||||
getPatInfo [] = []
|
||||
getPatInfo (x :: xs) = pat x :: getPatInfo xs
|
||||
@ -642,6 +650,135 @@ getFirstPat (p :: _) = pat p
|
||||
getFirstArgType : NamedPats ns (p :: ps) -> ArgType ns
|
||||
getFirstArgType (p :: _) = argType p
|
||||
|
||||
||| Store scores alongside rows of named patterns. These scores are used to determine
|
||||
||| which column of patterns to switch on first. One score per column.
|
||||
data ScoredPats : List Name -> List Name -> Type where
|
||||
Scored : List (NamedPats ns (p :: ps)) -> Vect (length (p :: ps)) Int -> ScoredPats ns (p :: ps)
|
||||
|
||||
{ps : _} -> Show (ScoredPats ns ps) where
|
||||
show (Scored xs ys) = (show ps) ++ "//" ++ (show ys)
|
||||
|
||||
||| Get proof that a value `v` inserted in the middle of a list with
|
||||
||| prefix `ps` and suffix `qs` can equivalently be snoced with
|
||||
||| `ps` or consed with `qs` before appending `qs` to `ps`.
|
||||
elemInsertedMiddle : (v : a) -> (ps,qs : List a) -> (ps ++ (v :: qs)) = ((ps `snoc` v) ++ qs)
|
||||
elemInsertedMiddle v [] qs = Refl
|
||||
elemInsertedMiddle v (x :: xs) qs = rewrite elemInsertedMiddle v xs qs in Refl
|
||||
|
||||
||| Helper to find a single highest scoring name (or none at all) while
|
||||
||| retaining the context of all names processed.
|
||||
highScore : {prev : List Name} ->
|
||||
(names : List Name) ->
|
||||
(scores : Vect (length names) Int) ->
|
||||
(highVal : Int) ->
|
||||
(highIdx : (n ** NVar n (prev ++ names))) ->
|
||||
(duped : Bool) ->
|
||||
Maybe (n ** NVar n (prev ++ names))
|
||||
highScore [] [] high idx True = Nothing
|
||||
highScore [] [] high idx False = Just idx
|
||||
highScore (x :: xs) (y :: ys) high idx duped =
|
||||
let next = highScore {prev = prev `snoc` x} xs ys
|
||||
prf = elemInsertedMiddle x prev xs
|
||||
in rewrite prf in
|
||||
case compare y high of
|
||||
LT => next high (rewrite sym $ prf in idx) duped
|
||||
EQ => next high (rewrite sym $ prf in idx) True
|
||||
GT => next y (x ** rewrite sym $ prf in weakenNVar (mkSizeOf prev) (MkNVar First)) False
|
||||
|
||||
||| Get the index of the highest scoring column if there is one.
|
||||
||| If no column has a higher score than all other columns then
|
||||
||| the result is Nothing indicating we need to apply more scoring
|
||||
||| to break the tie.
|
||||
||| Suggested heuristic application order: f, b, a.
|
||||
highScoreIdx : {p : _} -> {ps : _} -> ScoredPats ns (p :: ps) -> Maybe (n ** NVar n (p :: ps))
|
||||
highScoreIdx (Scored xs (y :: ys)) = highScore {prev = []} (p :: ps) (y :: ys) (y - 1) (p ** MkNVar First) False
|
||||
|
||||
||| Turn a Vect into a list and proof that the list's
|
||||
||| length is the same as the vector's length was.
|
||||
toList' : Vect l a -> (res : List a ** length res = l)
|
||||
toList' [] = ([] ** Refl)
|
||||
toList' (x :: xs) =
|
||||
let (rest ** prf) = toList' xs
|
||||
in (x :: rest ** cong S prf)
|
||||
|
||||
||| Apply the penalty function to the head constructor's
|
||||
||| arity. Produces 0 for all non-head-constructors.
|
||||
headConsPenalty : (penality : Nat -> Int) -> Pat -> Int
|
||||
headConsPenalty p (PAs _ _ w) = headConsPenalty p w
|
||||
headConsPenalty p (PCon _ n _ arity pats) = p arity
|
||||
headConsPenalty p (PTyCon _ _ arity _) = p arity
|
||||
headConsPenalty _ (PConst _ _) = 0
|
||||
headConsPenalty _ (PArrow _ _ _ _) = 0
|
||||
headConsPenalty p (PDelay _ _ _ w) = headConsPenalty p w
|
||||
headConsPenalty _ (PLoc _ _) = 0
|
||||
headConsPenalty _ (PUnmatchable _ _) = 0
|
||||
|
||||
consScoreHeuristic : {ps : _} -> (scorePat : Pat -> Int) -> ScoredPats ns ps -> ScoredPats ns ps
|
||||
consScoreHeuristic _ sps@(Scored [] _) = sps -- can't update scores without any patterns
|
||||
consScoreHeuristic scorePat (Scored xs ys) =
|
||||
let columnScores = sum <$> scoreColumns xs
|
||||
ys' = zipWith (+) ys columnScores
|
||||
in Scored xs ys'
|
||||
where
|
||||
-- also returns NamePats of remaining columns while its in there
|
||||
-- scoring the first column.
|
||||
scoreFirstColumn : (nps : List (NamedPats ns (p' :: ps'))) -> (Vect (length nps) (NamedPats ns ps'), Vect (length nps) Int)
|
||||
scoreFirstColumn [] = ([], [])
|
||||
scoreFirstColumn ((w :: ws) :: nps) =
|
||||
let (ws', scores) = scoreFirstColumn nps
|
||||
in (ws :: ws', scorePat (pat w) :: scores)
|
||||
|
||||
scoreColumns : {ps' : _} -> (nps : List (NamedPats ns ps')) -> Vect (length ps') (Vect (length nps) Int)
|
||||
scoreColumns {ps' = []} nps = []
|
||||
scoreColumns {ps' = (w :: ws)} nps =
|
||||
let (rest, firstCol) = scoreFirstColumn nps
|
||||
(rest' ** prf) = toList' rest
|
||||
in firstCol :: (rewrite sym $ prf in scoreColumns rest')
|
||||
|
||||
zeroed : {ps : _} -> List (NamedPats ns (p :: ps)) -> ScoredPats ns (p :: ps)
|
||||
zeroed _ = Scored [] (replicate (S $ length ps) 0)
|
||||
|
||||
||| Add 1 to each non-default pat in the first row.
|
||||
||| This favors constructive matching first and reduces tree depth on average.
|
||||
heuristicF : {ps : _} -> ScoredPats ns (p :: ps) -> ScoredPats ns (p :: ps)
|
||||
heuristicF sps@(Scored [] _) = sps
|
||||
heuristicF (Scored (x :: xs) ys) =
|
||||
let columnScores = scores x
|
||||
ys' = zipWith (+) ys columnScores
|
||||
in Scored (x :: xs) (scores x)
|
||||
where
|
||||
isBlank : Pat -> Bool
|
||||
isBlank (PLoc _ _) = True
|
||||
isBlank _ = False
|
||||
|
||||
scores : NamedPats ns' ps' -> Vect (length ps') Int
|
||||
scores [] = []
|
||||
scores (y :: ys) = let score : Int = if isBlank (pat y) then 0 else 1
|
||||
in score :: scores ys
|
||||
|
||||
||| Subtract 1 from each column for each pat that represents a head constructor.
|
||||
||| This favors pats that produce less branching.
|
||||
heuristicB : {ps : _} -> ScoredPats ns ps -> ScoredPats ns ps
|
||||
heuristicB = consScoreHeuristic (headConsPenalty (const $ -1))
|
||||
|
||||
||| Subtract the sum of the arities of constructors in each column.
|
||||
heuristicA : {ps : _} -> ScoredPats ns ps -> ScoredPats ns ps
|
||||
heuristicA = consScoreHeuristic (headConsPenalty (negate . cast))
|
||||
|
||||
||| Based only on the heuristic-score of columns, get the index of
|
||||
||| the column that should be processed next.
|
||||
nextIdx : {p : _} -> {ps : _} -> List (NamedPats ns (p :: ps)) -> (n ** NVar n (p :: ps))
|
||||
nextIdx xs =
|
||||
let scored = heuristicF $ zeroed xs
|
||||
in case highScoreIdx scored of
|
||||
Just s => s
|
||||
Nothing =>
|
||||
let scored' = heuristicB scored
|
||||
in case highScoreIdx scored' of
|
||||
Just s => s
|
||||
Nothing =>
|
||||
fromMaybe (_ ** (MkNVar First)) $ highScoreIdx $ heuristicA scored'
|
||||
|
||||
-- Check whether all the initial patterns have the same concrete, known
|
||||
-- and matchable type, which is multiplicity > 0.
|
||||
-- If so, it's okay to match on it
|
||||
@ -804,9 +941,19 @@ mutual
|
||||
-- inspect next has a concrete type that is the same in all cases, and
|
||||
-- has the most distinct constructors (via pickNext)
|
||||
match {todo = (_ :: _)} fc fn phase clauses err
|
||||
= do (n ** MkNVar next) <- pickNext fc phase fn (map getNPs clauses)
|
||||
= do let nps = getNPs <$> clauses
|
||||
let scores = heuristicF $ zeroed nps
|
||||
log "compile.casetree.debug" 1 ("\nF: " ++ (show scores))
|
||||
let scores' = heuristicB scores
|
||||
log "compile.casetree.debug" 1 ("\nB: " ++ (show scores'))
|
||||
let (n ** (MkNVar next)) = nextIdx nps
|
||||
log "compile.casetree" 26 $ "Want " ++ show n ++ " as the next split based on alphabet heuristics"
|
||||
let prioritizedClauses = shuffleVars next <$> clauses
|
||||
(n ** MkNVar next') <- pickNext fc phase fn (getNPs <$> prioritizedClauses)
|
||||
-- (n ** MkNVar next') <- pickNext fc phase fn (getNPs <$> clauses)
|
||||
log "compile.casetree" 25 $ "Picked " ++ show n ++ " as the next split"
|
||||
let clauses' = map (shuffleVars next) clauses
|
||||
let clauses' = shuffleVars next' <$> prioritizedClauses
|
||||
-- let clauses' = shuffleVars next' <$> clauses
|
||||
log "compile.casetree" 25 $ "Using clauses " ++ show clauses'
|
||||
let ps = partition phase clauses'
|
||||
log "compile.casetree" 25 $ "Got Partition " ++ show ps
|
||||
|
@ -45,6 +45,20 @@ mutual
|
||||
||| Catch-all case
|
||||
DefaultCase : CaseTree vars -> CaseAlt vars
|
||||
|
||||
mutual
|
||||
public export
|
||||
measure : CaseTree vars -> Nat
|
||||
measure (Case idx p scTy xs) = sum $ measureAlts <$> xs
|
||||
measure (STerm x y) = 0
|
||||
measure (Unmatched msg) = 0
|
||||
measure Impossible = 0
|
||||
|
||||
measureAlts : CaseAlt vars -> Nat
|
||||
measureAlts (ConCase x tag args y) = 1 + (measure y)
|
||||
measureAlts (DelayCase ty arg x) = 1 + (measure x)
|
||||
measureAlts (ConstCase x y) = 1 + (measure y)
|
||||
measureAlts (DefaultCase x) = 1 + (measure x)
|
||||
|
||||
public export
|
||||
data Pat : Type where
|
||||
PAs : FC -> Name -> Pat -> Pat
|
||||
|
@ -744,6 +744,7 @@ mkRunTime fc n
|
||||
, show (indent 2 $ pretty {ann = ()} !(toFullNames tree_rt))
|
||||
]
|
||||
log "compile.casetree" 10 $ show tree_rt
|
||||
log "compile.casetree.measure" 1 $ show (measure tree_rt)
|
||||
|
||||
let Just Refl = nameListEq cargs rargs
|
||||
| Nothing => throw (InternalError "WAT")
|
||||
|
Loading…
Reference in New Issue
Block a user