Merge pull request #1448 from mattpolzin/case-tree-experiments-merge-upstream

Case tree heuristics
This commit is contained in:
Edwin Brady 2021-07-16 08:47:19 +01:00 committed by GitHub
commit a20aba63ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 190 additions and 19 deletions

View File

@ -68,7 +68,6 @@ sdocToTreeParser (SLine i rest) = case sdocToTreeParser rest of
(Just tree, rest') => (Just $ STConcat [STLine i, tree], rest')
(Nothing, rest') => (Just $ STLine i, rest')
sdocToTreeParser (SAnnPush ann rest) = case sdocToTreeParser rest of
(tree, Nothing) => (Nothing, Nothing)
(Just tree, Nothing) => (Just $ STAnn ann tree, Nothing)
(Just tree, Just rest') => case sdocToTreeParser rest' of
(Just tree', rest'') => (Just $ STConcat [STAnn ann tree, tree'], rest'')

View File

@ -6,11 +6,13 @@ import Core.Context.Log
import Core.Core
import Core.Env
import Core.Normalise
import Core.Options
import Core.TT
import Core.Value
import Libraries.Data.LengthMatch
import Data.List
import Data.Vect
import Decidable.Equality
@ -647,6 +649,139 @@ getFirstPat (p :: _) = pat p
||| Apply `argType` to the first (leftmost) named pattern of a row that is
||| indexed as having at least one column.
getFirstArgType : NamedPats ns (p :: ps) -> ArgType ns
getFirstArgType (p :: _) = argType p
||| Store scores alongside rows of named patterns. These scores are used to determine
||| which column of patterns to switch on first. One score per column.
|||
||| `Scored` pairs the list of pattern rows with a vector holding exactly one
||| `Int` per column name in `p :: ps`. The column list is indexed as non-empty,
||| so a `Scored` value always carries at least one score.
data ScoredPats : List Name -> List Name -> Type where
Scored : List (NamedPats ns (p :: ps)) -> Vect (length (p :: ps)) Int -> ScoredPats ns (p :: ps)
-- Displays the column names followed by the per-column scores, separated by
-- "//". The pattern rows themselves are not part of the output.
{ps : _} -> Show (ScoredPats ns ps) where
  show (Scored _ scores) = show ps ++ "//" ++ show scores
||| Pair the given pattern rows with a score of 0 for every column.
zeroedScore : {ps : _} -> List (NamedPats ns (p :: ps)) -> ScoredPats ns (p :: ps)
zeroedScore rows = Scored rows (replicate (S (length ps)) 0)
||| Proof that placing a value `v` between a prefix `ps` and a suffix `qs`
||| yields the same list whether `v` is consed onto the front of `qs` or
||| snoced onto the end of `ps` before the two parts are appended.
elemInsertedMiddle : (v : a) -> (ps,qs : List a) -> (ps ++ (v :: qs)) = ((ps `snoc` v) ++ qs)
-- Empty prefix: the two sides are accepted as equal directly.
elemInsertedMiddle v [] qs = Refl
-- Non-empty prefix: rewrite using the proof for the tail of the prefix.
elemInsertedMiddle v (x :: xs) qs = rewrite elemInsertedMiddle v xs qs in Refl
||| Helper to find a single highest scoring name (or none at all) while
||| retaining the context of all names processed.
|||
||| Arguments:
|||  * `prev`    - names already visited; each visited name is snoced onto it
|||  * `names`   - names still to be visited
|||  * `scores`  - one score for each name in `names`
|||  * `highVal` - the highest score seen so far
|||  * `highIdx` - the variable currently associated with `highVal`
|||  * `duped`   - True when a later score has been equal to `highVal`
|||
||| Returns `Just` the variable holding the strictly highest score, or
||| `Nothing` when the highest score is shared.
highScore : {prev : List Name} ->
(names : List Name) ->
(scores : Vect (length names) Int) ->
(highVal : Int) ->
(highIdx : (n ** NVar n (prev ++ names))) ->
(duped : Bool) ->
Maybe (n ** NVar n (prev ++ names))
-- Every name has been visited: a shared top score means no unique winner.
highScore [] [] high idx True = Nothing
highScore [] [] high idx False = Just idx
highScore (x :: xs) (y :: ys) high idx duped =
-- `next` continues with `x` moved onto the end of the visited names.
let next = highScore {prev = prev `snoc` x} xs ys
-- `prf` converts between `prev ++ (x :: xs)` and `(prev `snoc` x) ++ xs`.
prf = elemInsertedMiddle x prev xs
in rewrite prf in
case compare y high of
-- Lower score: keep the current candidate and tie flag.
LT => next high (rewrite sym $ prf in idx) duped
-- Equal score: keep the candidate but record that it is no longer unique.
EQ => next high (rewrite sym $ prf in idx) True
-- Higher score: `x` becomes the candidate and the tie flag is cleared.
-- NOTE(review): presumably `weakenNVar (mkSizeOf prev)` shifts the variable
-- for `x` past the names in `prev` - confirm against its definition.
GT => next y (x ** rewrite sym $ prf in weakenNVar (mkSizeOf prev) (MkNVar First)) False
||| Find the column whose score is strictly greater than every other
||| column's score, if such a column exists.
||| A result of Nothing means the top score is shared, so further scoring
||| is needed to break the tie.
||| Suggested heuristic application order: f, b, a.
highScoreIdx : {p : _} -> {ps : _} -> ScoredPats ns (p :: ps) -> Maybe (n ** NVar n (p :: ps))
highScoreIdx (Scored _ scores@(headScore :: _))
    -- Seeding the running maximum with one less than the first score means the
    -- first column always compares as strictly higher and becomes the initial
    -- candidate.
    = highScore {prev = []} (p :: ps) scores (headScore - 1) (p ** MkNVar First) False
||| Score a pattern by passing the arity of its head constructor to the
||| given penalty function. As-patterns and delayed patterns are scored by
||| their inner pattern; every pattern without a head constructor scores 0.
headConsPenalty : (penality : Nat -> Int) -> Pat -> Int
-- Wrappers: look through to the wrapped pattern.
headConsPenalty penalise (PAs _ _ inner) = headConsPenalty penalise inner
headConsPenalty penalise (PDelay _ _ _ inner) = headConsPenalty penalise inner
-- Data and type constructors: penalise by arity.
headConsPenalty penalise (PCon _ _ _ arity _) = penalise arity
headConsPenalty penalise (PTyCon _ _ arity _) = penalise arity
-- Everything else carries no penalty.
headConsPenalty _ (PConst _ _) = 0
headConsPenalty _ (PArrow _ _ _ _) = 0
headConsPenalty _ (PLoc _ _) = 0
headConsPenalty _ (PUnmatchable _ _) = 0
||| Apply the given function that scores a pattern to all patterns and then
||| sum up the column scores and add to the ScoredPats passed in.
|||
||| Every pattern of every row is scored with `scorePat`; the scores within a
||| column are summed and the sum is added to that column's existing score.
||| When there are no rows the input is returned unchanged.
consScoreHeuristic : {ps : _} -> (scorePat : Pat -> Int) -> ScoredPats ns ps -> ScoredPats ns ps
consScoreHeuristic _ sps@(Scored [] _) = sps -- can't update scores without any patterns
consScoreHeuristic scorePat (Scored xs ys) =
-- One summed score per column, added pointwise to the existing scores.
let columnScores = sum <$> scoreColumns xs
ys' = zipWith (+) ys columnScores
in Scored xs ys'
where
-- Score the first pattern of every row with `scorePat`. While it's in there,
-- it also returns the NamedPats of the remaining columns (each row without
-- its first pattern), together with a proof that there are as many
-- shortened rows as there were rows.
scoreFirstColumn : (nps : List (NamedPats ns (p' :: ps'))) -> (res : List (NamedPats ns ps') ** (LengthMatch nps res, Vect (length nps) Int))
scoreFirstColumn [] = ([] ** (NilMatch, []))
scoreFirstColumn ((w :: ws) :: nps) =
let (ws' ** (prf, scores)) = scoreFirstColumn nps
in (ws :: ws' ** (ConsMatch prf, scorePat (pat w) :: scores))
-- Produce, for each column, the vector of per-row scores by repeatedly
-- scoring and dropping the first column.
scoreColumns : {ps' : _} -> (nps : List (NamedPats ns ps')) -> Vect (length ps') (Vect (length nps) Int)
scoreColumns {ps' = []} nps = []
scoreColumns {ps' = (w :: ws)} nps =
let (rest ** (prf, firstColScore)) = scoreFirstColumn nps
-- The rewrite uses the length proof so that scores computed over `rest`
-- are accepted as having one entry per row of `nps`.
in firstColScore :: (rewrite lengthsMatch prf in scoreColumns rest)
||| Add 1 to the score of every column whose pattern in the first row is not
||| a default (`PLoc`) pattern. Rows after the first are not inspected, and
||| with no rows at all the scores are returned unchanged.
||| This favors constructive matching first and reduces tree depth on average.
heuristicF : {ps : _} -> ScoredPats ns (p :: ps) -> ScoredPats ns (p :: ps)
heuristicF unscored@(Scored [] _) = unscored
heuristicF (Scored (firstRow :: rows) current)
    = Scored (firstRow :: rows) (zipWith (+) current (rowScores firstRow))
  where
    -- 0 for a default pattern, 1 for anything else.
    patScore : Pat -> Int
    patScore (PLoc _ _) = 0
    patScore _ = 1

    -- One score per pattern in the row, in column order.
    rowScores : NamedPats ns' ps' -> Vect (length ps') Int
    rowScores [] = []
    rowScores (np :: nps) = patScore (pat np) :: rowScores nps
||| Subtract 1 from a column's score for each pattern in it whose head
||| constructor has a non-zero arity; constructors with arity 0 and patterns
||| without a head constructor leave the score unchanged.
||| This favors pats that produce less branching.
heuristicB : {ps : _} -> ScoredPats ns ps -> ScoredPats ns ps
heuristicB = consScoreHeuristic (headConsPenalty branchPenalty)
  where
    branchPenalty : Nat -> Int
    branchPenalty Z = 0
    branchPenalty (S _) = (-1)
||| Lower each column's score by the total arity of the head constructors
||| appearing in that column.
heuristicA : {ps : _} -> ScoredPats ns ps -> ScoredPats ns ps
heuristicA = consScoreHeuristic (headConsPenalty (\arity => negate (cast arity)))
||| Look for a column with a uniquely highest score, applying the given
||| heuristics one at a time, in list order, for as long as no such column
||| exists. The scores are checked before each heuristic is applied and once
||| more after the last one; Nothing means the tie was never broken.
applyHeuristics : {p : _} ->
                  {ps : _} ->
                  ScoredPats ns (p :: ps) ->
                  List (ScoredPats ns (p :: ps) -> ScoredPats ns (p :: ps)) ->
                  Maybe (n ** NVar n (p :: ps))
applyHeuristics scored [] = highScoreIdx scored
applyHeuristics scored (heuristic :: remaining)
    = highScoreIdx scored <|> applyHeuristics (heuristic scored) remaining
||| Choose the column to process next using only heuristic scores.
|||
||| The first column is returned when heuristics are switched off, when
||| compiling for compile time, or when the heuristics cannot single out one
||| column. Otherwise the heuristics are tried in the order f, b, a, starting
||| from all-zero scores.
|||
||| The scoring is inspired by results from the paper:
||| http://moscova.inria.fr/~maranget/papers/ml05e-maranget.pdf
nextIdxByScore : {p : _} ->
                 {ps : _} ->
                 (useHeuristics : Bool) ->
                 Phase ->
                 List (NamedPats ns (p :: ps)) ->
                 (n ** NVar n (p :: ps))
nextIdxByScore False _ _ = (_ ** MkNVar First)
nextIdxByScore _ (CompileTime _) _ = (_ ** MkNVar First)
nextIdxByScore True RunTime xs
    = let winner = applyHeuristics (zeroedScore xs)
                                   [heuristicF, heuristicB, heuristicA]
      in fromMaybe (_ ** MkNVar First) winner
-- Check whether all the initial patterns have the same concrete, known
-- and matchable type, which is multiplicity > 0.
-- If so, it's okay to match on it
@ -761,26 +896,26 @@ getScore fc phase name npss
CaseCompile _ _ err => pure $ Left err
err => throw err
-- Pick the leftmost matchable thing with all constructors in the
-- same family, or all variables, or all the same type constructor.
pickNext : {p, ns, ps : _} ->
||| Pick the leftmost matchable thing with all constructors in the
||| same family, or all variables, or all the same type constructor.
pickNextViable : {p, ns, ps : _} ->
{auto i : Ref PName Int} ->
{auto c : Ref Ctxt Defs} ->
FC -> Phase -> Name -> List (NamedPats ns (p :: ps)) ->
Core (n ** NVar n (p :: ps))
-- last possible variable
pickNext {ps = []} fc phase fn npss
pickNextViable {ps = []} fc phase fn npss
= if samePat npss
then pure (_ ** MkNVar First)
else do Right () <- getScore fc phase fn npss
| Left err => throw (CaseCompile fc fn err)
pure (_ ** MkNVar First)
pickNext {ps = q :: qs} fc phase fn npss
pickNextViable {ps = q :: qs} fc phase fn npss
= if samePat npss
then pure (_ ** MkNVar First)
else case !(getScore fc phase fn npss) of
Right () => pure (_ ** MkNVar First)
_ => do (_ ** MkNVar var) <- pickNext fc phase fn (map tail npss)
_ => do (_ ** MkNVar var) <- pickNextViable fc phase fn (map tail npss)
pure (_ ** MkNVar (Later var))
moveFirst : {idx : Nat} -> (0 el : IsVar nm idx ps) -> NamedPats ns ps ->
@ -789,6 +924,7 @@ moveFirst el nps = getPat el nps :: dropPat el nps
||| Reorder the patterns of a clause so that the pattern for the variable
||| selected by `el` comes first, followed by the remaining patterns
||| (via `moveFirst`). The other fields of the clause are left untouched.
shuffleVars : {idx : Nat} -> (0 el : IsVar nm idx todo) -> PatClause vars todo ->
PatClause vars (nm :: dropVar todo el)
-- The selected variable is already first, so the clause is returned as it is.
shuffleVars First orig@(MkPatClause pvars lhs pid rhs) = orig -- no-op
shuffleVars el (MkPatClause pvars lhs pid rhs)
= MkPatClause pvars (moveFirst el lhs) pid rhs
@ -807,11 +943,14 @@ mutual
Core (CaseTree vars)
-- Before 'partition', reorder the arguments so that the one we
-- inspect next has a concrete type that is the same in all cases, and
-- has the most distinct constructors (via pickNext)
-- has the most distinct constructors (via pickNextViable)
match {todo = (_ :: _)} fc fn phase clauses err
= do (n ** MkNVar next) <- pickNext fc phase fn (map getNPs clauses)
= do let nps = getNPs <$> clauses
let (_ ** (MkNVar next)) = nextIdxByScore (caseTreeHeuristics !getSession) phase nps
let prioritizedClauses = shuffleVars next <$> clauses
(n ** MkNVar next') <- pickNextViable fc phase fn (getNPs <$> prioritizedClauses)
log "compile.casetree.pick" 25 $ "Picked " ++ show n ++ " as the next split"
let clauses' = map (shuffleVars next) clauses
let clauses' = shuffleVars next' <$> prioritizedClauses
log "compile.casetree.clauses" 25 $
unlines ("Using clauses:" :: map ((" " ++) . show) clauses')
let ps = partition phase clauses'

View File

@ -45,6 +45,20 @@ mutual
||| Catch-all case
DefaultCase : CaseTree vars -> CaseAlt vars
mutual
  ||| Size of a case tree, counted in case alternatives: a `Case` node is the
  ||| sum of the sizes of its alternatives, and every other node counts as 0.
  public export
  measure : CaseTree vars -> Nat
  measure (Case _ _ _ alts) = sum (map measureAlts alts)
  measure (STerm _ _) = 0
  measure (Unmatched _) = 0
  measure Impossible = 0

  -- Each alternative counts as one, plus the size of the tree it leads to.
  measureAlts : CaseAlt vars -> Nat
  measureAlts (ConCase _ _ _ sc) = S (measure sc)
  measureAlts (DelayCase _ _ sc) = S (measure sc)
  measureAlts (ConstCase _ sc) = S (measure sc)
  measureAlts (DefaultCase sc) = S (measure sc)
export
isDefault : CaseAlt vars -> Bool
isDefault (DefaultCase _) = True

View File

@ -178,6 +178,7 @@ record Session where
-- Use whole program compilation for executables, no matter what
-- incremental CGs are set (intended for overriding any environment
-- variables that set incremental compilation)
caseTreeHeuristics : Bool -- apply heuristics to pick matches for case tree building
public export
record PPrinter where
@ -228,7 +229,7 @@ defaultSession : Session
defaultSession = MkSessionOpts False False False Chez [] 1000 False False
defaultLogLevel False False Nothing Nothing
Nothing Nothing False 1 False True
False [] False
False [] False False
export
defaultElab : ElabDirectives

View File

@ -50,6 +50,7 @@ knownTopics = [
("compile.casetree.clauses", Nothing),
("compile.casetree.getpmdef", Nothing),
("compile.casetree.intermediate", Nothing),
("compile.casetree.measure", Just "Log the node counts of each runtime case tree."),
("compile.casetree.pick", Nothing),
("compile.casetree.partition", Nothing),
("compiler.inline.eval", Nothing),

View File

@ -133,20 +133,24 @@ data CLOpt
DebugElabCheck |
AltErrorCount Nat |
BlodwenPaths |
||| Treat warnings as errors
||| Treat warnings as errors
WarningsAsErrors |
||| Do not print shadowing warnings
||| Do not print shadowing warnings
IgnoreShadowingWarnings |
||| Use SHA256 hashes to determine if a source file needs rebuilding instead
||| of modification time.
||| Use SHA256 hashes to determine if a source file needs rebuilding instead
||| of modification time.
HashesInsteadOfModTime |
||| Use incremental code generation, if the backend supports it
||| Apply experimental heuristics to case tree generation that
||| sometimes improves performance and reduces compiled code
||| size.
CaseTreeHeuristics |
||| Use incremental code generation, if the backend supports it
IncrementalCG String |
||| Use whole program compilation - overrides IncrementalCG if set
||| Use whole program compilation - overrides IncrementalCG if set
WholeProgram |
||| Generate bash completion info
||| Generate bash completion info
BashCompletion String String |
||| Generate bash completion script
||| Generate bash completion script
BashCompletionScript String
||| Extract the host and port to bind the IDE socket to
@ -240,6 +244,8 @@ options = [MkOpt ["--check", "-c"] [] [CheckOnly]
optSeparator,
MkOpt ["-Xcheck-hashes"] [] [HashesInsteadOfModTime]
(Just "Use SHA256 hashes instead of modification time to determine if a source file needs rebuilding"),
MkOpt ["-Xcase-tree-opt"] [] [CaseTreeHeuristics]
(Just "Apply experimental optimizations to case tree generation"),
optSeparator,
MkOpt ["--prefix"] [] [ShowPrefix]

View File

@ -735,6 +735,7 @@ partitionOpts opts = foldr pOptUpdate (MkPFR [] [] False) opts
optType Verbose = POpt
optType Timing = POpt
optType (Logging l) = POpt
optType CaseTreeHeuristics = POpt
optType (DumpCases f) = POpt
optType (DumpLifted f) = POpt
optType (DumpVMCode f) = POpt

View File

@ -382,6 +382,9 @@ preOptions (IgnoreShadowingWarnings :: opts)
preOptions (HashesInsteadOfModTime :: opts)
= do setSession (record { checkHashesInsteadOfModTime = True } !getSession)
preOptions opts
preOptions (CaseTreeHeuristics :: opts)
= do setSession (record { caseTreeHeuristics = True } !getSession)
preOptions opts
preOptions (IncrementalCG e :: opts)
= do defs <- get Ctxt
setIncrementalCG True e

View File

@ -14,3 +14,9 @@ checkLengthMatch [] (x :: xs) = Nothing
checkLengthMatch (x :: xs) [] = Nothing
checkLengthMatch (x :: xs) (y :: ys)
= Just (ConsMatch !(checkLengthMatch xs ys))
||| Turn a `LengthMatch` between two lists into a proof that their lengths
||| are equal.
export
lengthsMatch : LengthMatch xs ys -> (length xs) = (length ys)
lengthsMatch NilMatch = Refl
lengthsMatch (ConsMatch rest) = cong S (lengthsMatch rest)

View File

@ -760,6 +760,7 @@ mkRunTime fc n
, show (indent 2 $ pretty {ann = ()} !(toFullNames tree_rt))
]
log "compile.casetree" 10 $ show tree_rt
log "compile.casetree.measure" 15 $ show (measure tree_rt)
let Just Refl = nameListEq cargs rargs
| Nothing => throw (InternalError "WAT")