mutual tail recursion

This commit is contained in:
Rui Barreiro 2020-10-05 14:39:38 +01:00
parent d56b090c4c
commit a5b2247f28
5 changed files with 302 additions and 79 deletions

View File

@ -211,9 +211,9 @@ compileToImperative c tm =
cdata <- getCompileData Cases tm
let ndefs = namedDefs cdata
let ctm = forget (mainExpr cdata)
s <- newRef Imps (MkImpSt 0)
newRef Imps (MkImpSt 0)
lst_defs <- traverse getImp (defsUsedByNamedCExp ctm ndefs)
let defs = concat lst_defs
let defs_optim = tailRecOptim defs
defs_optim <- tailRecOptim defs
(s, main) <- impExp False ctm
pure $ (defs_optim, s <+> EvalExpStatement main)

View File

@ -14,7 +14,7 @@ mutual
| IEPrimFnExt Name (List ImperativeExp)
| IEConstructorHead ImperativeExp
| IEConstructorTag (Either Int String)
| IEConstructorArg Int ImperativeExp
| IEConstructorArg Int ImperativeExp -- constructor arg index starts at 1
| IEConstructor (Either Int String) (List ImperativeExp)
| IEDelay ImperativeExp
| IEForce ImperativeExp
@ -80,64 +80,99 @@ mutual
mutual
public export
replaceNamesExp : List (Name, ImperativeExp) -> ImperativeExp -> ImperativeExp
replaceNamesExp reps (IEVar n) =
case lookup n reps of
Nothing => IEVar n
Just e => e
replaceNamesExp reps (IELambda args body) =
IELambda args $ replaceNamesExpS reps body
replaceNamesExp reps (IEApp f args) =
IEApp (replaceNamesExp reps f) (replaceNamesExp reps <$> args)
replaceNamesExp reps (IEConstant c) =
IEConstant c
replaceNamesExp reps (IEPrimFn f args) =
IEPrimFn f (replaceNamesExp reps <$> args)
replaceNamesExp reps (IEPrimFnExt f args) =
IEPrimFnExt f (replaceNamesExp reps <$> args)
replaceNamesExp reps (IEConstructorHead e) =
IEConstructorHead $ replaceNamesExp reps e
replaceNamesExp reps (IEConstructorTag i) =
IEConstructorTag i
replaceNamesExp reps (IEConstructorArg i e) =
IEConstructorArg i (replaceNamesExp reps e)
replaceNamesExp reps (IEConstructor t args) =
IEConstructor t (replaceNamesExp reps <$> args)
replaceNamesExp reps (IEDelay e) =
IEDelay $ replaceNamesExp reps e
replaceNamesExp reps (IEForce e) =
IEForce $ replaceNamesExp reps e
replaceNamesExp reps IENull =
IENull
replaceExp : (ImperativeExp -> Maybe ImperativeExp) -> ImperativeExp -> ImperativeExp
replaceExp rep x@(IEVar n) =
case rep x of
Just z => z
Nothing => x
replaceExp rep x@(IELambda args body) =
case rep x of
Just z => z
Nothing => IELambda args $ replaceExpS rep body
replaceExp rep x@(IEApp f args) =
case rep x of
Just z => z
Nothing => IEApp (replaceExp rep f) (replaceExp rep <$> args)
replaceExp rep x@(IEConstant c) =
case rep x of
Just z => z
Nothing => x
replaceExp rep x@(IEPrimFn f args) =
case rep x of
Just z => z
Nothing => IEPrimFn f (replaceExp rep <$> args)
replaceExp rep x@(IEPrimFnExt f args) =
case rep x of
Just z => z
Nothing => IEPrimFnExt f (replaceExp rep <$> args)
replaceExp rep x@(IEConstructorHead e) =
case rep x of
Just z => z
Nothing => IEConstructorHead $ replaceExp rep e
replaceExp rep x@(IEConstructorTag i) =
case rep x of
Just z => z
Nothing => x
replaceExp rep x@(IEConstructorArg i e) =
case rep x of
Just z => z
Nothing => IEConstructorArg i (replaceExp rep e)
replaceExp rep x@(IEConstructor t args) =
case rep x of
Just z => z
Nothing => IEConstructor t (replaceExp rep <$> args)
replaceExp rep x@(IEDelay e) =
case rep x of
Just z => z
Nothing => IEDelay $ replaceExp rep e
replaceExp rep x@(IEForce e) =
case rep x of
Just z => z
Nothing => IEForce $ replaceExp rep e
replaceExp rep x@IENull =
case rep x of
Just z => z
Nothing => x
public export
replaceNamesExpS : List (Name, ImperativeExp) -> ImperativeStatement -> ImperativeStatement
replaceNamesExpS reps DoNothing =
replaceExpS : (ImperativeExp -> Maybe ImperativeExp) -> ImperativeStatement -> ImperativeStatement
replaceExpS rep DoNothing =
DoNothing
replaceNamesExpS reps (SeqStatement x y) =
SeqStatement (replaceNamesExpS reps x) (replaceNamesExpS reps y)
replaceNamesExpS reps (FunDecl fc n args body) =
FunDecl fc n args $ replaceNamesExpS reps body
replaceNamesExpS reps (ForeignDecl n path) =
replaceExpS rep (SeqStatement x y) =
SeqStatement (replaceExpS rep x) (replaceExpS rep y)
replaceExpS rep (FunDecl fc n args body) =
FunDecl fc n args $ replaceExpS rep body
replaceExpS rep (ForeignDecl n path) =
ForeignDecl n path
replaceNamesExpS reps (ReturnStatement e) =
ReturnStatement $ replaceNamesExp reps e
replaceNamesExpS reps (SwitchStatement s alts def) =
let s_ = replaceNamesExp reps s
alts_ = (\(e,b) => (replaceNamesExp reps e, replaceNamesExpS reps b)) <$> alts
def_ = replaceNamesExpS reps <$> def
replaceExpS rep (ReturnStatement e) =
ReturnStatement $ replaceExp rep e
replaceExpS rep (SwitchStatement s alts def) =
let s_ = replaceExp rep s
alts_ = (\(e,b) => (replaceExp rep e, replaceExpS rep b)) <$> alts
def_ = replaceExpS rep <$> def
in SwitchStatement s_ alts_ def_
replaceNamesExpS reps (LetDecl n v) =
LetDecl n $ replaceNamesExp reps <$> v
replaceNamesExpS reps (ConstDecl n v) =
ConstDecl n $ replaceNamesExp reps v
replaceNamesExpS reps (MutateStatement n v) =
MutateStatement n $ replaceNamesExp reps v
replaceNamesExpS reps (ErrorStatement s) =
replaceExpS rep (LetDecl n v) =
LetDecl n $ replaceExp rep <$> v
replaceExpS rep (ConstDecl n v) =
ConstDecl n $ replaceExp rep v
replaceExpS rep (MutateStatement n v) =
MutateStatement n $ replaceExp rep v
replaceExpS rep (ErrorStatement s) =
ErrorStatement s
replaceNamesExpS reps (EvalExpStatement x) =
EvalExpStatement $ replaceNamesExp reps x
replaceNamesExpS reps (CommentStatement x) =
replaceExpS rep (EvalExpStatement x) =
EvalExpStatement $ replaceExp rep x
replaceExpS rep (CommentStatement x) =
CommentStatement x
replaceNamesExpS reps (ForEverLoop x) =
ForEverLoop $ replaceNamesExpS reps x
replaceExpS rep (ForEverLoop x) =
ForEverLoop $ replaceExpS rep x
rep : List (Name, ImperativeExp) -> ImperativeExp -> Maybe ImperativeExp
rep reps (IEVar n) =
lookup n reps
rep _ _ = Nothing
public export
replaceNamesExpS : List (Name, ImperativeExp) -> ImperativeStatement -> ImperativeStatement
replaceNamesExpS reps x =
replaceExpS (rep reps) x

View File

@ -1,26 +1,52 @@
module Compiler.ES.TailRec
import Data.Maybe
import Data.List
import Data.Strings
import Data.SortedSet
import Data.SortedMap
import Core.Name
import Core.Context
import Compiler.ES.ImperativeAst
hasTailCall : Name -> ImperativeStatement -> Bool
hasTailCall n (SeqStatement x y) =
hasTailCall n y
hasTailCall n (ReturnStatement x) =
import Debug.Trace
data TailRecS : Type where
record TailSt where
constructor MkTailSt
nextName : Int
genName : {auto c : Ref TailRecS TailSt} -> Core Name
genName =
do
s <- get TailRecS
let i = nextName s
put TailRecS (record { nextName = i + 1 } s)
pure $ MN "imp_gen_tailoptim" i
allTailCalls : ImperativeStatement -> SortedSet Name
allTailCalls (SeqStatement x y) =
allTailCalls y
allTailCalls (ReturnStatement x) =
case x of
IEApp (IEVar cn) _ => n == cn
_ => False
hasTailCall n (SwitchStatement e alts d) =
(any (\(_, w)=>hasTailCall n w) alts) || (maybe False (hasTailCall n) d)
hasTailCall n (ForEverLoop x) =
hasTailCall n x
hasTailCall n o = False
IEApp (IEVar n) _ => insert n empty
_ => empty
allTailCalls (SwitchStatement e alts d) =
maybe empty allTailCalls d `union` concat (map allTailCalls (map snd alts))
allTailCalls (ForEverLoop x) =
allTailCalls x
allTailCalls o = empty
argsName : Name
argsName = MN "tailRecOptimArgs" 0
argsName = MN "imp_gen_tailoptim_Args" 0
stepFnName : Name
stepFnName = MN "tailRecOptimStep" 0
stepFnName = MN "imp_gen_tailoptim_step" 0
fusionArgsName : Name
fusionArgsName = MN "imp_gen_tailoptim_fusion_args" 0
createNewArgs : List ImperativeExp -> ImperativeExp
createNewArgs values =
@ -40,8 +66,6 @@ replaceTailCall n (ReturnStatement x) =
if n == cn then ReturnStatement $ createNewArgs arg_vals
else defRet
_ => defRet
replaceTailCall n (SwitchStatement e alts d) =
SwitchStatement e (map (\(x,y) => (x, replaceTailCall n y)) alts) (map (replaceTailCall n) d)
replaceTailCall n (ForEverLoop x) =
@ -60,11 +84,169 @@ makeTailOptimToBody n argNames body =
loop = ForEverLoop $ SwitchStatement (IEConstructorHead $ IEVar argsName) [(IEConstructorTag $ Left 0, loopRec)] (Just loopReturn)
in stepFn <+> createArgInit argNames <+> loop
record CallGraph where
constructor MkCallGraph
calls, callers : SortedMap Name (SortedSet Name)
showCallGraph : CallGraph -> String
showCallGraph x =
let calls = unlines $ map
(\(x,y) => show x ++ ": " ++ show (SortedSet.toList y))
(SortedMap.toList x.calls)
callers = unlines $ map
(\(x,y) => show x ++ ": " ++ show (SortedSet.toList y))
(SortedMap.toList x.callers)
in calls ++ "\n----\n" ++ callers
tailCallGraph : ImperativeStatement -> CallGraph
tailCallGraph (SeqStatement x y) =
let xg = tailCallGraph x
yg = tailCallGraph y
in MkCallGraph
(mergeWith union xg.calls yg.calls)
(mergeWith union xg.callers yg.callers)
tailCallGraph (FunDecl fc n args body) =
let calls = allTailCalls body
in MkCallGraph (insert n calls empty) (SortedMap.fromList $ map (\x => (x, SortedSet.insert n empty)) (SortedSet.toList calls))
tailCallGraph _ = MkCallGraph empty empty
kosarajuL : SortedSet Name -> List Name -> CallGraph -> (SortedSet Name, List Name)
kosarajuL visited [] graph =
(visited, [])
kosarajuL visited (x::xs) graph =
if contains x visited then kosarajuL visited xs graph
else let outNeighbours = maybe [] SortedSet.toList $ lookup x (graph.calls)
(visited_, l_) = kosarajuL (insert x visited) (toList outNeighbours) graph
(visited__, l__) = kosarajuL visited_ xs graph
in (visited__, l__ ++ (x :: l_))
kosarajuAssign : CallGraph -> Name -> Name -> SortedMap Name Name -> SortedMap Name Name
kosarajuAssign graph u root s =
case lookup u s of
Just _ => s
Nothing => let inNeighbours = maybe [] SortedSet.toList $ lookup u (graph.callers)
in foldl (\acc, elem => kosarajuAssign graph elem root acc) (insert u root s) inNeighbours
kosaraju: CallGraph -> SortedMap Name Name
kosaraju graph =
let l = snd $ kosarajuL empty (keys $ graph.calls) graph
in foldl (\acc, elem => kosarajuAssign graph elem elem acc) empty l
groupBy : (a -> a -> Bool) -> List a -> List (List a)
groupBy _ [] = []
groupBy p' (x'::xs') =
let (ys',zs') = go p' x' xs'
in (x' :: ys') :: zs'
where
go : (a -> a -> Bool) -> a -> List a -> (List a, List (List a))
go p z (x::xs) =
let (ys,zs) = go p x xs
in case p z x of
True => (x :: ys, zs)
False => ([], (x :: ys) :: zs)
go _ _ [] = ([], [])
recursiveTailCallGroups : CallGraph -> List (List Name)
recursiveTailCallGroups graph =
let roots = kosaraju graph
groups = map (map fst) $
groupBy (\x,y => Builtin.snd x == Builtin.snd y) $
sortBy (\x,y=> compare (snd x) (snd y)) $
toList roots
in [x | x<-groups, hasTailCalls x]
where
hasTailCalls : List Name -> Bool
hasTailCalls [] =
False
hasTailCalls [x] =
case lookup x (graph.calls) of
Nothing =>
False
Just s =>
case SortedSet.toList s of
[n] => n == x
_ => False
hasTailCalls _ =
True
extractFunctions : SortedSet Name -> ImperativeStatement ->
(ImperativeStatement, SortedMap Name (FC, List Name, ImperativeStatement))
extractFunctions toExtract (SeqStatement x y) =
let (xs, xf) = extractFunctions toExtract x
(ys, yf) = extractFunctions toExtract y
in (xs <+> ys, mergeLeft xf yf)
extractFunctions toExtract f@(FunDecl fc n args body) =
if contains n toExtract then (neutral, insert n (fc, args, body) empty)
else (f, empty)
extractFunctions toExtract x =
(x, empty)
getDef : SortedMap Name (FC, List Name, ImperativeStatement) -> Name -> Core (FC, List Name, ImperativeStatement)
getDef defs n =
case lookup n defs of
Nothing => throw $ (InternalError $ "Can't find function definition in tailRecOptim")
Just x => pure x
makeFusionBranch : Name -> List (Name, Nat) -> (Nat, FC, List Name, ImperativeStatement) ->
(ImperativeExp, ImperativeStatement)
makeFusionBranch fusionName functionsIdx (i, _, args, body) =
let newArgExp = map (\i => IEConstructorArg (cast i) (IEVar fusionArgsName)) [1..(length args)]
bodyRepArgs = replaceNamesExpS (zip args newArgExp) body
in (IEConstructorTag $ Left $ cast i, replaceExpS rep bodyRepArgs)
where
rep : ImperativeExp -> Maybe ImperativeExp
rep (IEApp (IEVar n) arg_vals) =
case lookup n functionsIdx of
Nothing => Nothing
Just i => Just $ IEApp
(IEVar fusionName)
[IEConstructor (Left $ cast i) arg_vals]
rep _ = Nothing
changeBodyToUseFusion : Name -> (Nat, Name, FC, List Name, ImperativeStatement) -> ImperativeStatement
changeBodyToUseFusion fusionName (i, n, (fc, args, _)) =
FunDecl fc n args (ReturnStatement $ IEApp (IEVar fusionName) [IEConstructor (Left $ cast i) (map IEVar args)])
tailRecOptimGroup : {auto c : Ref TailRecS TailSt} ->
SortedMap Name (FC, List Name, ImperativeStatement) ->
List Name -> Core ImperativeStatement
tailRecOptimGroup defs [] = pure neutral
tailRecOptimGroup defs [n] =
do
(fc, args , body) <- getDef defs n
pure $ FunDecl fc n args (makeTailOptimToBody n args body)
tailRecOptimGroup defs names =
do
fusionName <- genName
d <- traverse (getDef defs) names
let ids = [0..(length names `minus` 1)]
let alts = map (makeFusionBranch fusionName (zip names ids)) (zip ids d)
let fusionBody = SwitchStatement
(IEConstructorHead $ IEVar fusionArgsName)
alts
Nothing
let fusionArgs = [fusionArgsName]
let fusion = FunDecl EmptyFC fusionName fusionArgs (makeTailOptimToBody fusionName fusionArgs fusionBody)
let newFunctions = Prelude.concat $ map
(changeBodyToUseFusion fusionName)
(ids `List.zip` (names `List.zip` d))
pure $ fusion <+> newFunctions
export
tailRecOptim : ImperativeStatement -> ImperativeStatement
tailRecOptim (SeqStatement x y) = SeqStatement (tailRecOptim x) (tailRecOptim y)
tailRecOptim (FunDecl fc n args body) =
let new_body = if hasTailCall n body then makeTailOptimToBody n args body
else body
in FunDecl fc n args new_body
tailRecOptim x = x
tailRecOptim : ImperativeStatement -> Core ImperativeStatement
tailRecOptim x =
do
newRef TailRecS (MkTailSt 0)
let graph = tailCallGraph x
let groups = recursiveTailCallGroups graph
let functionsToOptimize = foldl SortedSet.union empty $ map SortedSet.fromList groups
let (unchanged, defs) = extractFunctions functionsToOptimize x
optimized <- traverse (tailRecOptimGroup defs) groups
pure $ unchanged <+> (concat optimized)

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,10 @@
module Main
import Data.Vect
import Data.Stream
foo : List Char
foo = unpack $ pack $ take 4000 (repeat 'a')
factorialAux : Integer -> Integer -> Integer
factorialAux 0 a = a
@ -15,3 +19,4 @@ main =
printLn $ factorial 100
printLn $ factorial 10000
printLn $ show $ the (Vect 3 String) ["red", "green", "blue"]
printLn foo