mirror of
https://github.com/idris-lang/Idris2.git
synced 2024-11-28 02:23:44 +03:00
mutual tail recursion
This commit is contained in:
parent
d56b090c4c
commit
a5b2247f28
@ -211,9 +211,9 @@ compileToImperative c tm =
|
||||
cdata <- getCompileData Cases tm
|
||||
let ndefs = namedDefs cdata
|
||||
let ctm = forget (mainExpr cdata)
|
||||
s <- newRef Imps (MkImpSt 0)
|
||||
newRef Imps (MkImpSt 0)
|
||||
lst_defs <- traverse getImp (defsUsedByNamedCExp ctm ndefs)
|
||||
let defs = concat lst_defs
|
||||
let defs_optim = tailRecOptim defs
|
||||
defs_optim <- tailRecOptim defs
|
||||
(s, main) <- impExp False ctm
|
||||
pure $ (defs_optim, s <+> EvalExpStatement main)
|
||||
|
@ -14,7 +14,7 @@ mutual
|
||||
| IEPrimFnExt Name (List ImperativeExp)
|
||||
| IEConstructorHead ImperativeExp
|
||||
| IEConstructorTag (Either Int String)
|
||||
| IEConstructorArg Int ImperativeExp
|
||||
| IEConstructorArg Int ImperativeExp -- constructor arg index starts at 1
|
||||
| IEConstructor (Either Int String) (List ImperativeExp)
|
||||
| IEDelay ImperativeExp
|
||||
| IEForce ImperativeExp
|
||||
@ -80,64 +80,99 @@ mutual
|
||||
|
||||
mutual
|
||||
public export
|
||||
replaceNamesExp : List (Name, ImperativeExp) -> ImperativeExp -> ImperativeExp
|
||||
replaceNamesExp reps (IEVar n) =
|
||||
case lookup n reps of
|
||||
Nothing => IEVar n
|
||||
Just e => e
|
||||
replaceNamesExp reps (IELambda args body) =
|
||||
IELambda args $ replaceNamesExpS reps body
|
||||
replaceNamesExp reps (IEApp f args) =
|
||||
IEApp (replaceNamesExp reps f) (replaceNamesExp reps <$> args)
|
||||
replaceNamesExp reps (IEConstant c) =
|
||||
IEConstant c
|
||||
replaceNamesExp reps (IEPrimFn f args) =
|
||||
IEPrimFn f (replaceNamesExp reps <$> args)
|
||||
replaceNamesExp reps (IEPrimFnExt f args) =
|
||||
IEPrimFnExt f (replaceNamesExp reps <$> args)
|
||||
replaceNamesExp reps (IEConstructorHead e) =
|
||||
IEConstructorHead $ replaceNamesExp reps e
|
||||
replaceNamesExp reps (IEConstructorTag i) =
|
||||
IEConstructorTag i
|
||||
replaceNamesExp reps (IEConstructorArg i e) =
|
||||
IEConstructorArg i (replaceNamesExp reps e)
|
||||
replaceNamesExp reps (IEConstructor t args) =
|
||||
IEConstructor t (replaceNamesExp reps <$> args)
|
||||
replaceNamesExp reps (IEDelay e) =
|
||||
IEDelay $ replaceNamesExp reps e
|
||||
replaceNamesExp reps (IEForce e) =
|
||||
IEForce $ replaceNamesExp reps e
|
||||
replaceNamesExp reps IENull =
|
||||
IENull
|
||||
replaceExp : (ImperativeExp -> Maybe ImperativeExp) -> ImperativeExp -> ImperativeExp
|
||||
replaceExp rep x@(IEVar n) =
|
||||
case rep x of
|
||||
Just z => z
|
||||
Nothing => x
|
||||
replaceExp rep x@(IELambda args body) =
|
||||
case rep x of
|
||||
Just z => z
|
||||
Nothing => IELambda args $ replaceExpS rep body
|
||||
replaceExp rep x@(IEApp f args) =
|
||||
case rep x of
|
||||
Just z => z
|
||||
Nothing => IEApp (replaceExp rep f) (replaceExp rep <$> args)
|
||||
replaceExp rep x@(IEConstant c) =
|
||||
case rep x of
|
||||
Just z => z
|
||||
Nothing => x
|
||||
replaceExp rep x@(IEPrimFn f args) =
|
||||
case rep x of
|
||||
Just z => z
|
||||
Nothing => IEPrimFn f (replaceExp rep <$> args)
|
||||
replaceExp rep x@(IEPrimFnExt f args) =
|
||||
case rep x of
|
||||
Just z => z
|
||||
Nothing => IEPrimFnExt f (replaceExp rep <$> args)
|
||||
replaceExp rep x@(IEConstructorHead e) =
|
||||
case rep x of
|
||||
Just z => z
|
||||
Nothing => IEConstructorHead $ replaceExp rep e
|
||||
replaceExp rep x@(IEConstructorTag i) =
|
||||
case rep x of
|
||||
Just z => z
|
||||
Nothing => x
|
||||
replaceExp rep x@(IEConstructorArg i e) =
|
||||
case rep x of
|
||||
Just z => z
|
||||
Nothing => IEConstructorArg i (replaceExp rep e)
|
||||
replaceExp rep x@(IEConstructor t args) =
|
||||
case rep x of
|
||||
Just z => z
|
||||
Nothing => IEConstructor t (replaceExp rep <$> args)
|
||||
replaceExp rep x@(IEDelay e) =
|
||||
case rep x of
|
||||
Just z => z
|
||||
Nothing => IEDelay $ replaceExp rep e
|
||||
replaceExp rep x@(IEForce e) =
|
||||
case rep x of
|
||||
Just z => z
|
||||
Nothing => IEForce $ replaceExp rep e
|
||||
replaceExp rep x@IENull =
|
||||
case rep x of
|
||||
Just z => z
|
||||
Nothing => x
|
||||
|
||||
public export
|
||||
replaceNamesExpS : List (Name, ImperativeExp) -> ImperativeStatement -> ImperativeStatement
|
||||
replaceNamesExpS reps DoNothing =
|
||||
replaceExpS : (ImperativeExp -> Maybe ImperativeExp) -> ImperativeStatement -> ImperativeStatement
|
||||
replaceExpS rep DoNothing =
|
||||
DoNothing
|
||||
replaceNamesExpS reps (SeqStatement x y) =
|
||||
SeqStatement (replaceNamesExpS reps x) (replaceNamesExpS reps y)
|
||||
replaceNamesExpS reps (FunDecl fc n args body) =
|
||||
FunDecl fc n args $ replaceNamesExpS reps body
|
||||
replaceNamesExpS reps (ForeignDecl n path) =
|
||||
replaceExpS rep (SeqStatement x y) =
|
||||
SeqStatement (replaceExpS rep x) (replaceExpS rep y)
|
||||
replaceExpS rep (FunDecl fc n args body) =
|
||||
FunDecl fc n args $ replaceExpS rep body
|
||||
replaceExpS rep (ForeignDecl n path) =
|
||||
ForeignDecl n path
|
||||
replaceNamesExpS reps (ReturnStatement e) =
|
||||
ReturnStatement $ replaceNamesExp reps e
|
||||
replaceNamesExpS reps (SwitchStatement s alts def) =
|
||||
let s_ = replaceNamesExp reps s
|
||||
alts_ = (\(e,b) => (replaceNamesExp reps e, replaceNamesExpS reps b)) <$> alts
|
||||
def_ = replaceNamesExpS reps <$> def
|
||||
replaceExpS rep (ReturnStatement e) =
|
||||
ReturnStatement $ replaceExp rep e
|
||||
replaceExpS rep (SwitchStatement s alts def) =
|
||||
let s_ = replaceExp rep s
|
||||
alts_ = (\(e,b) => (replaceExp rep e, replaceExpS rep b)) <$> alts
|
||||
def_ = replaceExpS rep <$> def
|
||||
in SwitchStatement s_ alts_ def_
|
||||
replaceNamesExpS reps (LetDecl n v) =
|
||||
LetDecl n $ replaceNamesExp reps <$> v
|
||||
replaceNamesExpS reps (ConstDecl n v) =
|
||||
ConstDecl n $ replaceNamesExp reps v
|
||||
replaceNamesExpS reps (MutateStatement n v) =
|
||||
MutateStatement n $ replaceNamesExp reps v
|
||||
replaceNamesExpS reps (ErrorStatement s) =
|
||||
replaceExpS rep (LetDecl n v) =
|
||||
LetDecl n $ replaceExp rep <$> v
|
||||
replaceExpS rep (ConstDecl n v) =
|
||||
ConstDecl n $ replaceExp rep v
|
||||
replaceExpS rep (MutateStatement n v) =
|
||||
MutateStatement n $ replaceExp rep v
|
||||
replaceExpS rep (ErrorStatement s) =
|
||||
ErrorStatement s
|
||||
replaceNamesExpS reps (EvalExpStatement x) =
|
||||
EvalExpStatement $ replaceNamesExp reps x
|
||||
replaceNamesExpS reps (CommentStatement x) =
|
||||
replaceExpS rep (EvalExpStatement x) =
|
||||
EvalExpStatement $ replaceExp rep x
|
||||
replaceExpS rep (CommentStatement x) =
|
||||
CommentStatement x
|
||||
replaceNamesExpS reps (ForEverLoop x) =
|
||||
ForEverLoop $ replaceNamesExpS reps x
|
||||
replaceExpS rep (ForEverLoop x) =
|
||||
ForEverLoop $ replaceExpS rep x
|
||||
|
||||
|
||||
rep : List (Name, ImperativeExp) -> ImperativeExp -> Maybe ImperativeExp
|
||||
rep reps (IEVar n) =
|
||||
lookup n reps
|
||||
rep _ _ = Nothing
|
||||
|
||||
public export
|
||||
replaceNamesExpS : List (Name, ImperativeExp) -> ImperativeStatement -> ImperativeStatement
|
||||
replaceNamesExpS reps x =
|
||||
replaceExpS (rep reps) x
|
||||
|
@ -1,26 +1,52 @@
|
||||
module Compiler.ES.TailRec
|
||||
|
||||
import Data.Maybe
|
||||
import Data.List
|
||||
import Data.Strings
|
||||
import Data.SortedSet
|
||||
import Data.SortedMap
|
||||
import Core.Name
|
||||
import Core.Context
|
||||
import Compiler.ES.ImperativeAst
|
||||
|
||||
hasTailCall : Name -> ImperativeStatement -> Bool
|
||||
hasTailCall n (SeqStatement x y) =
|
||||
hasTailCall n y
|
||||
hasTailCall n (ReturnStatement x) =
|
||||
import Debug.Trace
|
||||
|
||||
data TailRecS : Type where
|
||||
|
||||
record TailSt where
|
||||
constructor MkTailSt
|
||||
nextName : Int
|
||||
|
||||
genName : {auto c : Ref TailRecS TailSt} -> Core Name
|
||||
genName =
|
||||
do
|
||||
s <- get TailRecS
|
||||
let i = nextName s
|
||||
put TailRecS (record { nextName = i + 1 } s)
|
||||
pure $ MN "imp_gen_tailoptim" i
|
||||
|
||||
allTailCalls : ImperativeStatement -> SortedSet Name
|
||||
allTailCalls (SeqStatement x y) =
|
||||
allTailCalls y
|
||||
allTailCalls (ReturnStatement x) =
|
||||
case x of
|
||||
IEApp (IEVar cn) _ => n == cn
|
||||
_ => False
|
||||
hasTailCall n (SwitchStatement e alts d) =
|
||||
(any (\(_, w)=>hasTailCall n w) alts) || (maybe False (hasTailCall n) d)
|
||||
hasTailCall n (ForEverLoop x) =
|
||||
hasTailCall n x
|
||||
hasTailCall n o = False
|
||||
IEApp (IEVar n) _ => insert n empty
|
||||
_ => empty
|
||||
allTailCalls (SwitchStatement e alts d) =
|
||||
maybe empty allTailCalls d `union` concat (map allTailCalls (map snd alts))
|
||||
allTailCalls (ForEverLoop x) =
|
||||
allTailCalls x
|
||||
allTailCalls o = empty
|
||||
|
||||
|
||||
argsName : Name
|
||||
argsName = MN "tailRecOptimArgs" 0
|
||||
argsName = MN "imp_gen_tailoptim_Args" 0
|
||||
|
||||
stepFnName : Name
|
||||
stepFnName = MN "tailRecOptimStep" 0
|
||||
stepFnName = MN "imp_gen_tailoptim_step" 0
|
||||
|
||||
fusionArgsName : Name
|
||||
fusionArgsName = MN "imp_gen_tailoptim_fusion_args" 0
|
||||
|
||||
createNewArgs : List ImperativeExp -> ImperativeExp
|
||||
createNewArgs values =
|
||||
@ -40,8 +66,6 @@ replaceTailCall n (ReturnStatement x) =
|
||||
if n == cn then ReturnStatement $ createNewArgs arg_vals
|
||||
else defRet
|
||||
_ => defRet
|
||||
|
||||
|
||||
replaceTailCall n (SwitchStatement e alts d) =
|
||||
SwitchStatement e (map (\(x,y) => (x, replaceTailCall n y)) alts) (map (replaceTailCall n) d)
|
||||
replaceTailCall n (ForEverLoop x) =
|
||||
@ -60,11 +84,169 @@ makeTailOptimToBody n argNames body =
|
||||
loop = ForEverLoop $ SwitchStatement (IEConstructorHead $ IEVar argsName) [(IEConstructorTag $ Left 0, loopRec)] (Just loopReturn)
|
||||
in stepFn <+> createArgInit argNames <+> loop
|
||||
|
||||
record CallGraph where
|
||||
constructor MkCallGraph
|
||||
calls, callers : SortedMap Name (SortedSet Name)
|
||||
|
||||
showCallGraph : CallGraph -> String
|
||||
showCallGraph x =
|
||||
let calls = unlines $ map
|
||||
(\(x,y) => show x ++ ": " ++ show (SortedSet.toList y))
|
||||
(SortedMap.toList x.calls)
|
||||
callers = unlines $ map
|
||||
(\(x,y) => show x ++ ": " ++ show (SortedSet.toList y))
|
||||
(SortedMap.toList x.callers)
|
||||
in calls ++ "\n----\n" ++ callers
|
||||
|
||||
|
||||
tailCallGraph : ImperativeStatement -> CallGraph
|
||||
tailCallGraph (SeqStatement x y) =
|
||||
let xg = tailCallGraph x
|
||||
yg = tailCallGraph y
|
||||
in MkCallGraph
|
||||
(mergeWith union xg.calls yg.calls)
|
||||
(mergeWith union xg.callers yg.callers)
|
||||
tailCallGraph (FunDecl fc n args body) =
|
||||
let calls = allTailCalls body
|
||||
in MkCallGraph (insert n calls empty) (SortedMap.fromList $ map (\x => (x, SortedSet.insert n empty)) (SortedSet.toList calls))
|
||||
tailCallGraph _ = MkCallGraph empty empty
|
||||
|
||||
kosarajuL : SortedSet Name -> List Name -> CallGraph -> (SortedSet Name, List Name)
|
||||
kosarajuL visited [] graph =
|
||||
(visited, [])
|
||||
kosarajuL visited (x::xs) graph =
|
||||
if contains x visited then kosarajuL visited xs graph
|
||||
else let outNeighbours = maybe [] SortedSet.toList $ lookup x (graph.calls)
|
||||
(visited_, l_) = kosarajuL (insert x visited) (toList outNeighbours) graph
|
||||
(visited__, l__) = kosarajuL visited_ xs graph
|
||||
in (visited__, l__ ++ (x :: l_))
|
||||
|
||||
|
||||
|
||||
kosarajuAssign : CallGraph -> Name -> Name -> SortedMap Name Name -> SortedMap Name Name
|
||||
kosarajuAssign graph u root s =
|
||||
case lookup u s of
|
||||
Just _ => s
|
||||
Nothing => let inNeighbours = maybe [] SortedSet.toList $ lookup u (graph.callers)
|
||||
in foldl (\acc, elem => kosarajuAssign graph elem root acc) (insert u root s) inNeighbours
|
||||
|
||||
kosaraju: CallGraph -> SortedMap Name Name
|
||||
kosaraju graph =
|
||||
let l = snd $ kosarajuL empty (keys $ graph.calls) graph
|
||||
in foldl (\acc, elem => kosarajuAssign graph elem elem acc) empty l
|
||||
|
||||
groupBy : (a -> a -> Bool) -> List a -> List (List a)
|
||||
groupBy _ [] = []
|
||||
groupBy p' (x'::xs') =
|
||||
let (ys',zs') = go p' x' xs'
|
||||
in (x' :: ys') :: zs'
|
||||
where
|
||||
go : (a -> a -> Bool) -> a -> List a -> (List a, List (List a))
|
||||
go p z (x::xs) =
|
||||
let (ys,zs) = go p x xs
|
||||
in case p z x of
|
||||
True => (x :: ys, zs)
|
||||
False => ([], (x :: ys) :: zs)
|
||||
go _ _ [] = ([], [])
|
||||
|
||||
recursiveTailCallGroups : CallGraph -> List (List Name)
|
||||
recursiveTailCallGroups graph =
|
||||
let roots = kosaraju graph
|
||||
groups = map (map fst) $
|
||||
groupBy (\x,y => Builtin.snd x == Builtin.snd y) $
|
||||
sortBy (\x,y=> compare (snd x) (snd y)) $
|
||||
toList roots
|
||||
in [x | x<-groups, hasTailCalls x]
|
||||
where
|
||||
hasTailCalls : List Name -> Bool
|
||||
hasTailCalls [] =
|
||||
False
|
||||
hasTailCalls [x] =
|
||||
case lookup x (graph.calls) of
|
||||
Nothing =>
|
||||
False
|
||||
Just s =>
|
||||
case SortedSet.toList s of
|
||||
[n] => n == x
|
||||
_ => False
|
||||
hasTailCalls _ =
|
||||
True
|
||||
|
||||
|
||||
extractFunctions : SortedSet Name -> ImperativeStatement ->
|
||||
(ImperativeStatement, SortedMap Name (FC, List Name, ImperativeStatement))
|
||||
extractFunctions toExtract (SeqStatement x y) =
|
||||
let (xs, xf) = extractFunctions toExtract x
|
||||
(ys, yf) = extractFunctions toExtract y
|
||||
in (xs <+> ys, mergeLeft xf yf)
|
||||
extractFunctions toExtract f@(FunDecl fc n args body) =
|
||||
if contains n toExtract then (neutral, insert n (fc, args, body) empty)
|
||||
else (f, empty)
|
||||
extractFunctions toExtract x =
|
||||
(x, empty)
|
||||
|
||||
getDef : SortedMap Name (FC, List Name, ImperativeStatement) -> Name -> Core (FC, List Name, ImperativeStatement)
|
||||
getDef defs n =
|
||||
case lookup n defs of
|
||||
Nothing => throw $ (InternalError $ "Can't find function definition in tailRecOptim")
|
||||
Just x => pure x
|
||||
|
||||
|
||||
makeFusionBranch : Name -> List (Name, Nat) -> (Nat, FC, List Name, ImperativeStatement) ->
|
||||
(ImperativeExp, ImperativeStatement)
|
||||
makeFusionBranch fusionName functionsIdx (i, _, args, body) =
|
||||
let newArgExp = map (\i => IEConstructorArg (cast i) (IEVar fusionArgsName)) [1..(length args)]
|
||||
bodyRepArgs = replaceNamesExpS (zip args newArgExp) body
|
||||
in (IEConstructorTag $ Left $ cast i, replaceExpS rep bodyRepArgs)
|
||||
where
|
||||
rep : ImperativeExp -> Maybe ImperativeExp
|
||||
rep (IEApp (IEVar n) arg_vals) =
|
||||
case lookup n functionsIdx of
|
||||
Nothing => Nothing
|
||||
Just i => Just $ IEApp
|
||||
(IEVar fusionName)
|
||||
[IEConstructor (Left $ cast i) arg_vals]
|
||||
rep _ = Nothing
|
||||
|
||||
changeBodyToUseFusion : Name -> (Nat, Name, FC, List Name, ImperativeStatement) -> ImperativeStatement
|
||||
changeBodyToUseFusion fusionName (i, n, (fc, args, _)) =
|
||||
FunDecl fc n args (ReturnStatement $ IEApp (IEVar fusionName) [IEConstructor (Left $ cast i) (map IEVar args)])
|
||||
|
||||
tailRecOptimGroup : {auto c : Ref TailRecS TailSt} ->
|
||||
SortedMap Name (FC, List Name, ImperativeStatement) ->
|
||||
List Name -> Core ImperativeStatement
|
||||
tailRecOptimGroup defs [] = pure neutral
|
||||
tailRecOptimGroup defs [n] =
|
||||
do
|
||||
(fc, args , body) <- getDef defs n
|
||||
pure $ FunDecl fc n args (makeTailOptimToBody n args body)
|
||||
tailRecOptimGroup defs names =
|
||||
do
|
||||
fusionName <- genName
|
||||
d <- traverse (getDef defs) names
|
||||
let ids = [0..(length names `minus` 1)]
|
||||
let alts = map (makeFusionBranch fusionName (zip names ids)) (zip ids d)
|
||||
let fusionBody = SwitchStatement
|
||||
(IEConstructorHead $ IEVar fusionArgsName)
|
||||
alts
|
||||
Nothing
|
||||
let fusionArgs = [fusionArgsName]
|
||||
let fusion = FunDecl EmptyFC fusionName fusionArgs (makeTailOptimToBody fusionName fusionArgs fusionBody)
|
||||
let newFunctions = Prelude.concat $ map
|
||||
(changeBodyToUseFusion fusionName)
|
||||
(ids `List.zip` (names `List.zip` d))
|
||||
pure $ fusion <+> newFunctions
|
||||
|
||||
|
||||
|
||||
export
|
||||
tailRecOptim : ImperativeStatement -> ImperativeStatement
|
||||
tailRecOptim (SeqStatement x y) = SeqStatement (tailRecOptim x) (tailRecOptim y)
|
||||
tailRecOptim (FunDecl fc n args body) =
|
||||
let new_body = if hasTailCall n body then makeTailOptimToBody n args body
|
||||
else body
|
||||
in FunDecl fc n args new_body
|
||||
tailRecOptim x = x
|
||||
tailRecOptim : ImperativeStatement -> Core ImperativeStatement
|
||||
tailRecOptim x =
|
||||
do
|
||||
newRef TailRecS (MkTailSt 0)
|
||||
let graph = tailCallGraph x
|
||||
let groups = recursiveTailCallGroups graph
|
||||
let functionsToOptimize = foldl SortedSet.union empty $ map SortedSet.fromList groups
|
||||
let (unchanged, defs) = extractFunctions functionsToOptimize x
|
||||
optimized <- traverse (tailRecOptimGroup defs) groups
|
||||
pure $ unchanged <+> (concat optimized)
|
||||
|
File diff suppressed because one or more lines are too long
@ -1,6 +1,10 @@
|
||||
module Main
|
||||
|
||||
import Data.Vect
|
||||
import Data.Stream
|
||||
|
||||
foo : List Char
|
||||
foo = unpack $ pack $ take 4000 (repeat 'a')
|
||||
|
||||
factorialAux : Integer -> Integer -> Integer
|
||||
factorialAux 0 a = a
|
||||
@ -15,3 +19,4 @@ main =
|
||||
printLn $ factorial 100
|
||||
printLn $ factorial 10000
|
||||
printLn $ show $ the (Vect 3 String) ["red", "green", "blue"]
|
||||
printLn foo
|
||||
|
Loading…
Reference in New Issue
Block a user