initial implementation of 3 heuristics.

This commit is contained in:
Mathew Polzin 2021-05-02 00:26:17 -07:00
parent c725b11c89
commit 7ca774d8e6
3 changed files with 164 additions and 2 deletions

View File

@ -11,6 +11,9 @@ import Core.Value
import Libraries.Data.LengthMatch
import Data.List
import Data.Vect
import Debug.Trace
import Decidable.Equality
@ -94,6 +97,11 @@ data NamedPats : List Name -> -- pattern variables still to process
-- refers to are explicit
NamedPats vars ns -> NamedPats vars (pvar :: ns)
length : NamedPats vars ps -> (n ** n = length ps)
length [] = (0 ** Refl)
length (_ :: xs) = let (n' ** prf) = length xs
in ((S n') ** cong S prf)
getPatInfo : NamedPats vars todo -> List Pat
getPatInfo [] = []
getPatInfo (x :: xs) = pat x :: getPatInfo xs
@ -642,6 +650,135 @@ getFirstPat (p :: _) = pat p
getFirstArgType : NamedPats ns (p :: ps) -> ArgType ns
getFirstArgType (p :: _) = argType p
||| Store scores alongside rows of named patterns. These scores are used to determine
||| which column of patterns to switch on first. One score per column.
data ScoredPats : List Name -> List Name -> Type where
Scored : List (NamedPats ns (p :: ps)) -> Vect (length (p :: ps)) Int -> ScoredPats ns (p :: ps)
{ps : _} -> Show (ScoredPats ns ps) where
show (Scored xs ys) = (show ps) ++ "//" ++ (show ys)
||| Get proof that a value `v` inserted in the middle of a list with
||| prefix `ps` and suffix `qs` can equivalently be snoced with
||| `ps` or consed with `qs` before appending `qs` to `ps`.
elemInsertedMiddle : (v : a) -> (ps,qs : List a) -> (ps ++ (v :: qs)) = ((ps `snoc` v) ++ qs)
elemInsertedMiddle v [] qs = Refl
elemInsertedMiddle v (x :: xs) qs = rewrite elemInsertedMiddle v xs qs in Refl
||| Helper to find a single highest scoring name (or none at all) while
||| retaining the context of all names processed.
highScore : {prev : List Name} ->
(names : List Name) ->
(scores : Vect (length names) Int) ->
(highVal : Int) ->
(highIdx : (n ** NVar n (prev ++ names))) ->
(duped : Bool) ->
Maybe (n ** NVar n (prev ++ names))
highScore [] [] high idx True = Nothing
highScore [] [] high idx False = Just idx
highScore (x :: xs) (y :: ys) high idx duped =
let next = highScore {prev = prev `snoc` x} xs ys
prf = elemInsertedMiddle x prev xs
in rewrite prf in
case compare y high of
LT => next high (rewrite sym $ prf in idx) duped
EQ => next high (rewrite sym $ prf in idx) True
GT => next y (x ** rewrite sym $ prf in weakenNVar (mkSizeOf prev) (MkNVar First)) False
||| Get the index of the highest scoring column if there is one.
||| If no column has a higher score than all other columns then
||| the result is Nothing indicating we need to apply more scoring
||| to break the tie.
||| Suggested heuristic application order: f, b, a.
highScoreIdx : {p : _} -> {ps : _} -> ScoredPats ns (p :: ps) -> Maybe (n ** NVar n (p :: ps))
highScoreIdx (Scored xs (y :: ys)) = highScore {prev = []} (p :: ps) (y :: ys) (y - 1) (p ** MkNVar First) False
||| Turn a Vect into a list and proof that the list's
||| length is the same as the vector's length was.
toList' : Vect l a -> (res : List a ** length res = l)
toList' [] = ([] ** Refl)
toList' (x :: xs) =
let (rest ** prf) = toList' xs
in (x :: rest ** cong S prf)
||| Apply the penalty function to the head constructor's
||| arity. Produces 0 for all non-head-constructors.
headConsPenalty : (penality : Nat -> Int) -> Pat -> Int
headConsPenalty p (PAs _ _ w) = headConsPenalty p w
headConsPenalty p (PCon _ n _ arity pats) = p arity
headConsPenalty p (PTyCon _ _ arity _) = p arity
headConsPenalty _ (PConst _ _) = 0
headConsPenalty _ (PArrow _ _ _ _) = 0
headConsPenalty p (PDelay _ _ _ w) = headConsPenalty p w
headConsPenalty _ (PLoc _ _) = 0
headConsPenalty _ (PUnmatchable _ _) = 0
consScoreHeuristic : {ps : _} -> (scorePat : Pat -> Int) -> ScoredPats ns ps -> ScoredPats ns ps
consScoreHeuristic _ sps@(Scored [] _) = sps -- can't update scores without any patterns
consScoreHeuristic scorePat (Scored xs ys) =
let columnScores = sum <$> scoreColumns xs
ys' = zipWith (+) ys columnScores
in Scored xs ys'
where
-- also returns NamePats of remaining columns while its in there
-- scoring the first column.
scoreFirstColumn : (nps : List (NamedPats ns (p' :: ps'))) -> (Vect (length nps) (NamedPats ns ps'), Vect (length nps) Int)
scoreFirstColumn [] = ([], [])
scoreFirstColumn ((w :: ws) :: nps) =
let (ws', scores) = scoreFirstColumn nps
in (ws :: ws', scorePat (pat w) :: scores)
scoreColumns : {ps' : _} -> (nps : List (NamedPats ns ps')) -> Vect (length ps') (Vect (length nps) Int)
scoreColumns {ps' = []} nps = []
scoreColumns {ps' = (w :: ws)} nps =
let (rest, firstCol) = scoreFirstColumn nps
(rest' ** prf) = toList' rest
in firstCol :: (rewrite sym $ prf in scoreColumns rest')
zeroed : {ps : _} -> List (NamedPats ns (p :: ps)) -> ScoredPats ns (p :: ps)
zeroed _ = Scored [] (replicate (S $ length ps) 0)
||| Add 1 to each non-default pat in the first row.
||| This favors constructive matching first and reduces tree depth on average.
heuristicF : {ps : _} -> ScoredPats ns (p :: ps) -> ScoredPats ns (p :: ps)
heuristicF sps@(Scored [] _) = sps
heuristicF (Scored (x :: xs) ys) =
let columnScores = scores x
ys' = zipWith (+) ys columnScores
in Scored (x :: xs) (scores x)
where
isBlank : Pat -> Bool
isBlank (PLoc _ _) = True
isBlank _ = False
scores : NamedPats ns' ps' -> Vect (length ps') Int
scores [] = []
scores (y :: ys) = let score : Int = if isBlank (pat y) then 0 else 1
in score :: scores ys
||| Subtract 1 from each column for each pat that represents a head constructor.
||| This favors pats that produce less branching.
heuristicB : {ps : _} -> ScoredPats ns ps -> ScoredPats ns ps
heuristicB = consScoreHeuristic (headConsPenalty (const $ -1))
||| Subtract the sum of the arities of constructors in each column.
heuristicA : {ps : _} -> ScoredPats ns ps -> ScoredPats ns ps
heuristicA = consScoreHeuristic (headConsPenalty (negate . cast))
||| Based only on the heuristic-score of columns, get the index of
||| the column that should be processed next.
nextIdx : {p : _} -> {ps : _} -> List (NamedPats ns (p :: ps)) -> (n ** NVar n (p :: ps))
nextIdx xs =
let scored = heuristicF $ zeroed xs
in case highScoreIdx scored of
Just s => s
Nothing =>
let scored' = heuristicB scored
in case highScoreIdx scored' of
Just s => s
Nothing =>
fromMaybe (_ ** (MkNVar First)) $ highScoreIdx $ heuristicA scored'
-- Check whether all the initial patterns have the same concrete, known
-- and matchable type, which is multiplicity > 0.
-- If so, it's okay to match on it
@ -804,9 +941,19 @@ mutual
-- inspect next has a concrete type that is the same in all cases, and
-- has the most distinct constructors (via pickNext)
match {todo = (_ :: _)} fc fn phase clauses err
= do (n ** MkNVar next) <- pickNext fc phase fn (map getNPs clauses)
= do let nps = getNPs <$> clauses
let scores = heuristicF $ zeroed nps
log "compile.casetree.debug" 1 ("\nF: " ++ (show scores))
let scores' = heuristicB scores
log "compile.casetree.debug" 1 ("\nB: " ++ (show scores'))
let (n ** (MkNVar next)) = nextIdx nps
log "compile.casetree" 26 $ "Want " ++ show n ++ " as the next split based on alphabet heuristics"
let prioritizedClauses = shuffleVars next <$> clauses
(n ** MkNVar next') <- pickNext fc phase fn (getNPs <$> prioritizedClauses)
-- (n ** MkNVar next') <- pickNext fc phase fn (getNPs <$> clauses)
log "compile.casetree" 25 $ "Picked " ++ show n ++ " as the next split"
let clauses' = map (shuffleVars next) clauses
let clauses' = shuffleVars next' <$> prioritizedClauses
-- let clauses' = shuffleVars next' <$> clauses
log "compile.casetree" 25 $ "Using clauses " ++ show clauses'
let ps = partition phase clauses'
log "compile.casetree" 25 $ "Got Partition " ++ show ps

View File

@ -45,6 +45,20 @@ mutual
||| Catch-all case
DefaultCase : CaseTree vars -> CaseAlt vars
mutual
public export
measure : CaseTree vars -> Nat
measure (Case idx p scTy xs) = sum $ measureAlts <$> xs
measure (STerm x y) = 0
measure (Unmatched msg) = 0
measure Impossible = 0
measureAlts : CaseAlt vars -> Nat
measureAlts (ConCase x tag args y) = 1 + (measure y)
measureAlts (DelayCase ty arg x) = 1 + (measure x)
measureAlts (ConstCase x y) = 1 + (measure y)
measureAlts (DefaultCase x) = 1 + (measure x)
public export
data Pat : Type where
PAs : FC -> Name -> Pat -> Pat

View File

@ -744,6 +744,7 @@ mkRunTime fc n
, show (indent 2 $ pretty {ann = ()} !(toFullNames tree_rt))
]
log "compile.casetree" 10 $ show tree_rt
log "compile.casetree.measure" 1 $ show (measure tree_rt)
let Just Refl = nameListEq cargs rargs
| Nothing => throw (InternalError "WAT")