diff --git a/src/Compiler/ES/TailRec.idr b/src/Compiler/ES/TailRec.idr index 105131197..21fe81f9e 100644 --- a/src/Compiler/ES/TailRec.idr +++ b/src/Compiler/ES/TailRec.idr @@ -195,16 +195,13 @@ hasTailCalls g (x ::: Nil) = maybe False (contains x) $ lookup x g hasTailCalls _ _ = True -- Given a strongly connected group of functions, plus --- a unique index, convert them to a list of --- consisting of the function names plus the `TcGroup`, --- to which they belong. +-- a unique index, convert them to the `TcGroup` they belong to. toGroup : SortedMap Name (Name,List Name,NamedCExp) -> (Int,List1 Name) - -> List (Name,TcGroup) + -> TcGroup toGroup funMap (groupIndex,functions) = let ns = zipWithIndices $ forget functions - group = MkTcGroup groupIndex . fromList $ mapMaybe fun ns - in (,group) <$> forget functions + in MkTcGroup groupIndex . fromList $ mapMaybe fun ns where fun : (Int,Name) -> Maybe (Name,TcFunction) fun (fx, n) = do (_,args,exp) <- lookup n funMap @@ -212,16 +209,13 @@ toGroup funMap (groupIndex,functions) = -- Returns the connected components of the tail call graph -- of a set of toplevel function definitions. --- Every function name that is part of a tail-call group --- (a set of mutually tail-recursive functions) --- points to its corresponding group. -tailCallGroups : List (Name,List Name,NamedCExp) - -> SortedMap Name TcGroup +-- Every `TcGroup` consists of a set of mutually tail-recursive functions. +tailCallGroups : List (Name,List Name,NamedCExp) -> List TcGroup tailCallGroups funs = let funMap = M.fromList $ map (\t => (fst t,t)) funs graph = map (\(_,_,x) => tailCalls x) funMap groups = filter (hasTailCalls graph) $ tarjan graph - in fromList $ concatMap (toGroup funMap) (zipWithIndices groups) + in map (toGroup funMap) (zipWithIndices groups) -------------------------------------------------------------------------------- -- Converting tail call groups to expressions @@ -319,32 +313,36 @@ convertTcGroup loop g@(MkTcGroup gindex fs) = -- Tail recursion optimizations: Converts all groups of -- mutually tail recursive functions to an imperative loop. -tailRecOptim : SortedMap Name TcGroup +tailRecOptim : List TcGroup + -> (tcOptimized : SortedSet Name) -> (tcLoopName : Name) -> List (Name,List Name,NamedCExp) -> List Function -tailRecOptim groups loop ts = +tailRecOptim groups names loop ts = let regular = mapMaybe toFun ts - tailOpt = concatMap (convertTcGroup loop) $ values groups + tailOpt = concatMap (convertTcGroup loop) groups in tailOpt ++ regular where toFun : (Name,List Name,NamedCExp) -> Maybe Function - toFun (n,args,exp) = case lookup n groups of - Just _ => Nothing - Nothing => Just $ MkFunction n args exp + toFun (n,args,exp) = + if contains n names + then Nothing + else Just $ MkFunction n args exp ||| Converts a list of toplevel definitions (potentially ||| several groups of mutually tail-recursive functions) ||| to a new set of tail-call optimized function definitions. -||| `MkNmFun`s are converted. Other constructors of `NamedDef` +||| Only `MkNmFun`s are converted. Other constructors of `NamedDef` ||| are ignored and silently dropped. export functions : (tcLoopName : Name) -> List (Name,FC,NamedDef) -> List Function functions loop dfs = - let ts = mapMaybe def dfs - in tailRecOptim (tailCallGroups ts) loop ts + let ts = mapMaybe def dfs + groups = tailCallGroups ts + names = SortedSet.fromList $ concatMap (keys . functions) groups + in tailRecOptim groups names loop ts where def : (Name,FC,NamedDef) -> Maybe (Name,List Name,NamedCExp) def (n,_,MkNmFun args x) = Just (n,args,x) def _ = Nothing