diff --git a/grin/grin.cabal b/grin/grin.cabal index b98c04f0..5750bccd 100644 --- a/grin/grin.cabal +++ b/grin/grin.cabal @@ -146,10 +146,23 @@ library Transformations.ExtendedSyntax.GenerateEval Transformations.ExtendedSyntax.MangleNames Transformations.ExtendedSyntax.StaticSingleAssignment + Transformations.ExtendedSyntax.Optimising.ArityRaising + Transformations.ExtendedSyntax.Optimising.CaseCopyPropagation + Transformations.ExtendedSyntax.Optimising.CaseHoisting Transformations.ExtendedSyntax.Optimising.CopyPropagation + Transformations.ExtendedSyntax.Optimising.ConstantPropagation Transformations.ExtendedSyntax.Optimising.CSE + Transformations.ExtendedSyntax.Optimising.DeadDataElimination Transformations.ExtendedSyntax.Optimising.DeadFunctionElimination + Transformations.ExtendedSyntax.Optimising.DeadParameterElimination Transformations.ExtendedSyntax.Optimising.EvaluatedCaseElimination + Transformations.ExtendedSyntax.Optimising.Inlining + Transformations.ExtendedSyntax.Optimising.GeneralizedUnboxing + Transformations.ExtendedSyntax.Optimising.NonSharedElimination + Transformations.ExtendedSyntax.Optimising.SimpleDeadFunctionElimination + Transformations.ExtendedSyntax.Optimising.SimpleDeadParameterElimination + Transformations.ExtendedSyntax.Optimising.SimpleDeadVariableElimination + Transformations.ExtendedSyntax.Optimising.SparseCaseOptimisation Transformations.ExtendedSyntax.Optimising.TrivialCaseElimination Transformations.BindNormalisation @@ -302,10 +315,22 @@ test-suite grin-test Transformations.ExtendedSyntax.ConversionSpec Transformations.ExtendedSyntax.MangleNamesSpec Transformations.ExtendedSyntax.StaticSingleAssignmentSpec + Transformations.ExtendedSyntax.Optimising.ArityRaisingSpec + Transformations.ExtendedSyntax.Optimising.CaseCopyPropagationSpec + Transformations.ExtendedSyntax.Optimising.CaseHoistingSpec Transformations.ExtendedSyntax.Optimising.CopyPropagationSpec Transformations.ExtendedSyntax.Optimising.CSESpec + Transformations.ExtendedSyntax.Optimising.DeadDataEliminationSpec Transformations.ExtendedSyntax.Optimising.DeadFunctionEliminationSpec + Transformations.ExtendedSyntax.Optimising.DeadParameterEliminationSpec Transformations.ExtendedSyntax.Optimising.EvaluatedCaseEliminationSpec + Transformations.ExtendedSyntax.Optimising.InliningSpec + Transformations.ExtendedSyntax.Optimising.GeneralizedUnboxingSpec + Transformations.ExtendedSyntax.Optimising.NonSharedEliminationSpec + Transformations.ExtendedSyntax.Optimising.SimpleDeadFunctionEliminationSpec + Transformations.ExtendedSyntax.Optimising.SimpleDeadParameterEliminationSpec + Transformations.ExtendedSyntax.Optimising.SimpleDeadVariableEliminationSpec + Transformations.ExtendedSyntax.Optimising.SparseCaseOptimisationSpec Transformations.ExtendedSyntax.Optimising.TrivialCaseEliminationSpec Transformations.Simplifying.RegisterIntroductionSpec diff --git a/grin/src/Transformations/ExtendedSyntax/Conversion.hs b/grin/src/Transformations/ExtendedSyntax/Conversion.hs index ba6358bc..5b7b89ea 100644 --- a/grin/src/Transformations/ExtendedSyntax/Conversion.hs +++ b/grin/src/Transformations/ExtendedSyntax/Conversion.hs @@ -198,7 +198,7 @@ instance Convertible Exp New.Exp where -} (EBind lhs (ConstTagNode tag args) rhs) -> do - asPatName <- deriveNewName "a" + asPatName <- deriveNewName "conv" newNodePat <- oldNodePatToAsPat tag args asPatName pure $ New.EBindF lhs newNodePat rhs (EBind lhs (Var var) rhs) diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/ArityRaising.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/ArityRaising.hs new file mode 100644 index 00000000..7f2aa41a --- /dev/null +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/ArityRaising.hs @@ -0,0 +1,207 @@ +{-# LANGUAGE LambdaCase, TupleSections #-} +module Transformations.ExtendedSyntax.Optimising.ArityRaising where + +import Data.List (nub) +import Data.Maybe (fromJust, isJust, mapMaybe, catMaybes) +import Data.Functor.Foldable +import qualified Data.Set as Set; import Data.Set (Set) +import qualified Data.Map.Strict as Map; import Data.Map (Map) +import qualified Data.Vector as Vector; import Data.Vector (Vector) + +import Control.Monad.State.Strict + +import Grin.ExtendedSyntax.Grin (packName, unpackName) +import Grin.ExtendedSyntax.Syntax +import Grin.ExtendedSyntax.TypeEnv +import Transformations.ExtendedSyntax.Names + + +{- +1. Select one function which has a parameter of a pointer to one constructor only. +2. If the parameter is linear and fetched in the function body then this is a good function for + arity raising + +How to raise arity? +1. Change the function parameters: replace the parameter with the parameters in the constructor +2. Change the function body: remove the fectch and use the variables as parameters +3. Change the caller sides: instead of passing the pointer fetch the pointer and pass the values are parameters + +How to handle self recursion? +1. If a function is self recursive, the paramter that is fetched originaly in the function body + must be passed as normal parameters in the same function call. + +Phase 1: Select a function and a parameter to transform. +Phase 2: Transform the parameter and the function body. +Phase 3: Transform the callers. + +This way the fetches propagates slowly to the caller side to the creational point. + +Parameters: + - Used only in fetch or in recursive calls for the same function. + - Its value points to a location, which location has only one Node with at least one parameter +-} + +-- TODO: True is reported even if exp stayed the same. Investigate why exp stay the same +-- for non-null arity data. +arityRaising :: Int -> TypeEnv -> Exp -> (Exp, ExpChanges) +arityRaising n te exp = if Map.null arityData then (exp, NoChange) else (phase2 n arityData exp, NewNames) + where + arityData = phase1 te exp + +-- | ArityData maps a function name to its arguments that can be arity raised. +-- 1st: Name of the argument +-- 2nd: The index of the argument +-- 3rd: The tag and one possible locaition where the parameter can point to. +type ArityData = Map Name [(Name, Int, (Tag, Int))] + +type ParameterInfo = Map Name (Int, (Tag, Int)) + +data Phase1Data + = ProgramData { pdArityData :: ArityData } + | FunData { fdArityData :: ArityData } + | BodyData { bdFunCall :: [(Name, Name)] + , bdFetch :: Map Name Int + , bdOther :: [Name] + } + deriving (Show) + +instance Semigroup Phase1Data where + (ProgramData ad0) <> (ProgramData ad1) = ProgramData (Map.unionWith mappend ad0 ad1) + (FunData fd0) <> (FunData fd1) = FunData (mappend fd0 fd1) + (BodyData c0 f0 o0) <> (BodyData c1 f1 o1) = BodyData (c0 ++ c1) (Map.unionWith (+) f0 f1) (o0 ++ o1) + +instance Monoid Phase1Data where + mempty = BodyData mempty mempty mempty + +variableInVar :: Val -> [Name] +variableInVar (Var v) = [v] +variableInVar _ = [] + +variableInNode :: Val -> [Name] +variableInNode (ConstTagNode _ vs) = vs +variableInNode _ = [] + +variableInNodes :: [Val] -> [Name] +variableInNodes = concatMap variableInNode + +phase1 :: TypeEnv -> Exp -> ArityData +phase1 te = pdArityData . cata collect where + collect :: ExpF Phase1Data -> Phase1Data + collect = \case + SAppF fn ps -> mempty { bdFunCall = map (fn,) ps, bdOther = ps } + SFetchF var -> mempty { bdFetch = Map.singleton var 1 } + SUpdateF ptr var -> mempty { bdOther = [ptr, var] } + SReturnF val -> mempty { bdOther = variableInNode val ++ variableInVar val } + SStoreF v -> mempty { bdOther = [v] } + SBlockF ad -> ad + AltF _ _ ad -> ad + ECaseF scrut alts -> mconcat alts <> mempty { bdOther = [scrut] } + EBindF lhs _ rhs -> lhs <> rhs + + -- Keep the parameters that are locations and points to a single node with at least one parameters + -- - that are not appear in others + -- - that are not appear in other function calls + -- - that are fetched at least once + DefF fn ps body -> + let funData = + [ (p,i,(fromJust mtag)) + | (p,i) <- ps `zip` [1..] + , Map.member p (bdFetch body) + , let mtag = pointsToOneNode te p + , isJust mtag + , p `notElem` (bdOther body) + , p `notElem` (snd <$> (filter ((/=fn) . fst) (bdFunCall body))) + ] + in FunData $ case funData of + [] -> Map.empty + _ -> Map.singleton fn funData + + ProgramF exts defs -> ProgramData $ Map.unionsWith mappend (fdArityData <$> defs) + +pointsToOneNode :: TypeEnv -> Name -> Maybe (Tag, Int) +pointsToOneNode te var = case Map.lookup var (_variable te) of + (Just (T_SimpleType (T_Location locs))) -> case nub $ concatMap Map.keys $ ((_location te) Vector.!) <$> locs of + [tag] -> Just (tag, Vector.length $ head $ Map.elems $ (_location te) Vector.! (head locs)) + _ -> Nothing + _ -> Nothing + +type VarM a = StateT Int NameM a + +evalVarM :: Int -> Exp -> VarM a -> a +evalVarM n exp = fst . evalNameM exp . flip evalStateT n + +{- +Phase2 and Phase3 can be implemented in one go. + +Change only the functions which are in the ArityData map, left the others out. + * Change fetches to pure, using the tag information provided + * Change funcall parameters + * Change fundef parameters + +Use the original parameter name with new indices, thus we dont need a name generator. +-} +phase2 :: Int -> ArityData -> Exp -> Exp +phase2 n arityData exp = evalVarM 0 exp $ cata change exp where + fetchParNames :: Name -> Int -> Int -> [Name] + fetchParNames nm idx i = (\j -> packName $ concat [unpackName nm,".",show n,".",show idx,".arity.",show j]) <$> [1..i] + + newParNames :: Name -> Int -> [Name] + newParNames nm i = (\j -> packName $ concat [unpackName nm,".",show n,".arity.",show j]) <$> [1..i] + + parameterInfo :: ParameterInfo + parameterInfo = Map.fromList $ map (\(n,ith,tag) -> (n, (ith, tag))) $ concat $ Map.elems arityData + + replace_parameters_with_new_ones = concatMap $ \case + p | Just (nth, (tag, ps)) <- Map.lookup p parameterInfo -> + newParNames p ps + | otherwise -> [p] + + change :: ExpF (VarM Exp) -> (VarM Exp) + change = \case + {- Change only function bodies that are in the ArityData + from: (CNode c1 cn) <- fetch pi + to: (CNode c1 cn) <- pure (CNode pi1 pin) + + from: funcall p1 pi pn + to: rec-funcall p1 pi1 pin pn + to: do (CNode c1 cn) <- fetch pi + non-rec-funcall p1 c1 cn pn + + from: fundef p1 pi pn + to: fundef p1 pi1 pin pn + -} + SFetchF var + | Just (nth, (tag, ps)) <- Map.lookup var parameterInfo -> + pure $ SReturn (ConstTagNode tag (newParNames var ps)) + | otherwise -> + pure $ SFetch var + + SAppF f fps + | Just aritedParams <- Map.lookup f arityData -> do + idx <- get + let qsi = Map.fromList $ map (\(_,i,t) -> (i,t)) aritedParams + nsi = Map.fromList $ map (\(n,i,t) -> (n,t)) aritedParams + psi = [1..] `zip` fps + newPs = flip concatMap psi $ \case + (_, n) | Just (t, jth) <- Map.lookup n nsi -> newParNames n jth + (i, n) | Just (t, jth) <- Map.lookup i qsi -> fetchParNames n idx jth + -- (i, Undefined{}) | Just (_, jth) <- Map.lookup i qsi -> replicate jth (Undefined dead_t) + -- (_, other) -> [other] + fetches <- fmap catMaybes $ forM psi $ \case + (_, n) | Just _ <- Map.lookup n nsi -> pure Nothing + (i, n) | Just (t, jth) <- Map.lookup i qsi -> do + asPatName <- lift deriveWildCard + pure $ Just (AsPat t (fetchParNames n idx jth) asPatName, SFetch n) + _ -> pure Nothing + put (idx + 1) + pure $ case fetches of + [] -> SApp f newPs + _ -> SBlock $ foldr (\(pat, fetch) rest -> EBind fetch pat rest) (SApp f newPs) fetches + | otherwise -> + pure $ SApp f fps + + DefF f ps new + | Map.member f arityData -> Def f (replace_parameters_with_new_ones ps) <$> new + | otherwise -> Def f ps <$> new + + rest -> embed <$> sequence rest diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/CaseCopyPropagation.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/CaseCopyPropagation.hs new file mode 100644 index 00000000..a925ed26 --- /dev/null +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/CaseCopyPropagation.hs @@ -0,0 +1,81 @@ +{-# LANGUAGE LambdaCase #-} +module Transformations.ExtendedSyntax.Optimising.CaseCopyPropagation where + +import Data.Map (Map) +import Data.Functor.Foldable + +import qualified Data.Map as Map + +import Control.Monad.State + +import Grin.ExtendedSyntax.Grin +import Transformations.ExtendedSyntax.Names +import Transformations.ExtendedSyntax.Util (cataM) + + +-- NOTE: ~ Maybe Tag +data TagInfo = Unknown | Known Tag + deriving (Eq, Ord, Show) + +-- | Maps alt names to TagInfo +type InfoTable = Map Name TagInfo + +-- NOTE: Case Copy Propagtion ~ Case Unboxing +caseCopyPropagation :: Exp -> (Exp, ExpChanges) +caseCopyPropagation e = rebindCases infoTable e where + infoTable = collectTagInfo e + +-- | Collects tag information about case alternatives. +collectTagInfo :: Exp -> InfoTable +collectTagInfo = flip execState mempty . cataM alg where + + alg :: ExpF TagInfo -> State InfoTable TagInfo + alg = \case + SBlockF tagInfo -> pure tagInfo + EBindF _ _ rhsTagInfo -> pure rhsTagInfo + ECaseF scrut altTagInfo -> pure $ commonTag altTagInfo + SReturnF (ConstTagNode tag [arg]) -> pure $ Known tag + + AltF _ name tagInfo -> do + modify (Map.insert name tagInfo) + pure tagInfo + + _ -> pure Unknown + +-- | Rebinds unboxable case expressions, and unboxes +-- the corresponding alternatives' last return expressions. +rebindCases :: InfoTable -> Exp -> (Exp, ExpChanges) +rebindCases infoTable e = evalNameM e $ cataM alg e where + + alg :: ExpF Exp -> NameM Exp + alg = \case + ECaseF scrut alts + | Known tag <- lookupCommonTag [ name | Alt _ name _ <- alts ] + , alts' <- [ Alt cpat name (unboxLastReturn body) | Alt cpat name body <- alts ] + , case' <- ECase scrut alts' + -> do + res <- deriveNewName "ccp" + pure $ SBlock $ EBind case' (VarPat res) (SReturn $ ConstTagNode tag [res]) + e -> pure $ embed e + + -- | Determine the common tag for a set of alternatives (if it exists). + lookupCommonTag :: [Name] -> TagInfo + lookupCommonTag = + commonTag + . map (\alt -> Map.findWithDefault Unknown alt infoTable) + +-- | Unboxes the last node-returning expression in a binding sequence. +unboxLastReturn :: Exp -> Exp +unboxLastReturn = apo coAlg where + + coAlg :: Exp -> ExpF (Either Exp Exp) + coAlg = \case + SReturn (ConstTagNode _ [arg]) -> SReturnF (Var arg) + EBind lhs bPat rhs -> EBindF (Left lhs) bPat (Right rhs) + SBlock body -> SBlockF (Right body) + e -> Left <$> project e + +commonTag :: [TagInfo] -> TagInfo +commonTag (t : ts) + | all (==t) ts = t +commonTag _ = Unknown diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/CaseHoisting.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/CaseHoisting.hs new file mode 100644 index 00000000..f29f1090 --- /dev/null +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/CaseHoisting.hs @@ -0,0 +1,123 @@ +{-# LANGUAGE LambdaCase, TupleSections #-} +module Transformations.ExtendedSyntax.Optimising.CaseHoisting where + +import Control.Monad +import Control.Comonad +import Control.Comonad.Cofree +import Data.Functor.Foldable as Foldable +import qualified Data.Foldable + +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Set (Set) +import qualified Data.Set as Set +import qualified Data.Vector as Vector +import Data.Bifunctor (first) + +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.TypeEnv +import Transformations.ExtendedSyntax.Util +import Transformations.ExtendedSyntax.Names + +{- + IDEA: + If Alt had name then the HPT could calculate it's return type and store in TypeEnv +-} + +getReturnTagSet :: TypeEnv -> Exp -> Maybe (Set Tag) +getReturnTagSet typeEnv = cata folder where + folder exp = case exp of + EBindF _ _ ts -> ts + SBlockF ts -> ts + AltF _ _ ts -> ts + ECaseF _ alts -> mconcat <$> sequence alts + + SReturnF val + | Just (T_NodeSet ns) <- mTypeOfValTE typeEnv val + -> Just (Map.keysSet ns) + + SAppF name _ + | T_NodeSet ns <- fst $ functionType typeEnv name + -> Just (Map.keysSet ns) + + SFetchF name + | T_SimpleType (T_Location locs) <- variableType typeEnv name + -> Just (mconcat [Map.keysSet (_location typeEnv Vector.! loc) | loc <- locs]) + + _ -> Nothing + + +caseHoisting :: TypeEnv -> Exp -> (Exp, ExpChanges) +caseHoisting typeEnv exp = first fst $ evalNameM exp $ histoM folder exp where + + folder :: ExpF (Cofree ExpF (Exp, Set Name)) -> NameM (Exp, Set Name) + folder exp = case exp of + -- middle case + EBindF ((ECase val alts1, leftUse) :< _) (VarPat lpatName) + (_ :< (EBindF ((ECase varName alts2, caseUse) :< _) lpat ((rightExp, rightUse) :< _))) + | lpatName == varName + , Just alts1Types <- sequence $ map (getReturnTagSet typeEnv) alts1 + , Just matchList <- disjointMatch (zip alts1Types alts1) alts2 + , Set.notMember varName rightUse -- allow only linear variables ; that are not used later + -> do + hoistedAlts <- mapM (hoistAlts lpatName) matchList + pure (EBind (ECase val hoistedAlts) lpat rightExp, Set.delete varName $ mconcat [leftUse, caseUse, rightUse]) + + -- last case + EBindF ((ECase val alts1, leftUse) :< _) (VarPat lpatName) ((ECase varName alts2, rightUse) :< _) + | lpatName == varName + , Just alts1Types <- sequence $ map (getReturnTagSet typeEnv) alts1 + , Just matchList <- disjointMatch (zip alts1Types alts1) alts2 + -> do + hoistedAlts <- mapM (hoistAlts lpatName) matchList + pure (ECase val hoistedAlts, Set.delete varName $ mconcat [leftUse, rightUse]) + + _ -> let useSub = Data.Foldable.fold (snd . extract <$> exp) + useExp = foldNameUseExpF Set.singleton exp + in pure (embed (fst . extract <$> exp), mconcat [useSub, useExp]) + +hoistAlts :: Name -> (Alt, Alt) -> NameM Alt +hoistAlts lpatName (Alt cpat1 altName1 alt1, Alt cpat2 altName2 alt2) = do + freshLPatName <- deriveNewName lpatName + let nameMap = Map.singleton lpatName freshLPatName + (freshAlt2, _) <- refreshNames nameMap $ + EBind (SReturn $ Var freshLPatName) (VarPat altName2) alt2 + pure . Alt cpat1 altName1 $ EBind (SBlock alt1) (VarPat freshLPatName) freshAlt2 + +disjointMatch :: [(Set Tag, Alt)] -> [Alt] -> Maybe [(Alt, Alt)] +disjointMatch tsAlts1 alts2 + | Just (defaults, tagMap) <- mconcat <$> mapM groupByCPats alts2 + , length defaults <= 1 + , Just (altPairs, _, _) <- Data.Foldable.foldrM (matchAlt tagMap) ([], defaults, Set.empty) tsAlts1 + = Just altPairs +disjointMatch _ _ = Nothing + +groupByCPats :: Alt -> Maybe ([Alt], Map Tag Alt) +groupByCPats alt@(Alt cpat _ _) = case cpat of + DefaultPat -> Just ([alt], mempty) + NodePat tag _ -> Just ([], Map.singleton tag alt) + _ -> Nothing + +matchAlt :: Map Tag Alt -> (Set Tag, Alt) -> ([(Alt, Alt)], [Alt], Set Tag) -> Maybe ([(Alt, Alt)], [Alt], Set Tag) +matchAlt tagMap (ts, alt1) (matchList, defaults, coveredTags) + -- regular node pattern + | Set.size ts == 1 + , tag <- Set.findMin ts + , Set.notMember tag coveredTags + , Just alt2 <- Map.lookup tag tagMap + = Just ((alt1, alt2):matchList, defaults, Set.insert tag coveredTags) + + -- default can handle this + | defaultAlt:[] <- defaults + , Data.Foldable.all (flip Set.notMember coveredTags) ts + = Just ((alt1, defaultAlt):matchList, [], coveredTags `mappend` ts) + + | otherwise = Nothing + +{- + TODO: + - add cloned variables to TypeEnv + done - ignore non linear scrutinee + IDEA: + this could be supported if product type was available in GRIN then the second case could return from the hoisted case with a pair of the original two case results +-} diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/ConstantPropagation.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/ConstantPropagation.hs new file mode 100644 index 00000000..f526b3f6 --- /dev/null +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/ConstantPropagation.hs @@ -0,0 +1,72 @@ +{-# LANGUAGE LambdaCase, TupleSections, ViewPatterns #-} +module Transformations.ExtendedSyntax.Optimising.ConstantPropagation where + + +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map +import Data.Functor.Foldable + +import Lens.Micro ((^.)) + +import Grin.ExtendedSyntax.Grin +import Transformations.ExtendedSyntax.Util + +{- + HINT: + propagates only tag values but not literals + GRIN is not a supercompiler + + NOTE: + We only need the tag information to simplify case expressions. + This means that Env could be a Name -> Tag mapping. +-} + +type Env = Map Name Val + +constantPropagation :: Exp -> Exp +constantPropagation e = ana builder (mempty, e) where + + builder :: (Env, Exp) -> ExpF (Env, Exp) + builder (env, exp) = case exp of + ECase scrut alts -> + let constVal = getValue scrut env + known = isKnown constVal || Map.member scrut env + matchingAlts = [alt | alt@(Alt cpat name body) <- alts, match cpat constVal] + defaultAlts = [alt | alt@(Alt DefaultPat name body) <- alts] + -- HINT: use cpat as known value in the alternative ; bind cpat to val + altEnv cpat = env `mappend` unify env scrut (cPatToVal cpat) + in case (known, matchingAlts, defaultAlts) of + -- known scutinee, specific pattern + (True, [Alt cpat name body], _) -> (env,) <$> SBlockF (EBind (SReturn $ constVal) (cPatToAsPat cpat name) body) + + -- known scutinee, default pattern + (True, _, [Alt DefaultPat name body]) -> (env,) <$> SBlockF (EBind (SReturn $ Var scrut) (VarPat name) body) + + -- unknown scutinee + -- HINT: in each alternative set val value like it was matched + _ -> ECaseF scrut [(altEnv cpat, alt) | alt@(Alt cpat name _) <- alts] + + -- track values + EBind (SReturn val) bPat rightExp -> (env `mappend` unify env (bPat ^. _BPatVar) val,) <$> project exp + + _ -> (env,) <$> project exp + + unify :: Env -> Name -> Val -> Env + unify env var val = case val of + ConstTagNode{} -> Map.singleton var val + Unit -> Map.singleton var val -- HINT: default pattern (minor hack) + Var v -> Map.singleton var (getValue v env) + Lit{} -> mempty + _ -> error $ "ConstantPropagation/unify: unexpected value: " ++ show (val) -- TODO: PP + + isKnown :: Val -> Bool + isKnown = \case + ConstTagNode{} -> True + _ -> False + + match :: CPat -> Val -> Bool + match (NodePat tagA _) (ConstTagNode tagB _) = tagA == tagB + match _ _ = False + + getValue :: Name -> Env -> Val + getValue varName env = Map.findWithDefault (Var varName) varName env diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/CopyPropagation.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/CopyPropagation.hs index 22135fbe..6dfc675a 100644 --- a/grin/src/Transformations/ExtendedSyntax/Optimising/CopyPropagation.hs +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/CopyPropagation.hs @@ -1,7 +1,7 @@ {-# LANGUAGE LambdaCase, TupleSections, ViewPatterns #-} module Transformations.ExtendedSyntax.Optimising.CopyPropagation where -import Control.Monad.State +import Control.Monad.State.Strict import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/DeadDataElimination.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/DeadDataElimination.hs new file mode 100644 index 00000000..1812130b --- /dev/null +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/DeadDataElimination.hs @@ -0,0 +1,262 @@ +{-# LANGUAGE LambdaCase, RecordWildCards, ViewPatterns #-} +module Transformations.ExtendedSyntax.Optimising.DeadDataElimination where + +import Data.Set (Set) +import qualified Data.Set as Set +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Vector (Vector) +import qualified Data.Vector as Vec + +import Data.List +import Data.Maybe +import Data.Functor.Foldable as Foldable + +import Control.Monad +import Control.Monad.State.Strict +import Control.Monad.Trans.Except + +import Lens.Micro + +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.Pretty +import Grin.ExtendedSyntax.TypeEnv + +import AbstractInterpretation.ExtendedSyntax.CreatedBy.Util +import AbstractInterpretation.ExtendedSyntax.CreatedBy.Result +import AbstractInterpretation.ExtendedSyntax.LiveVariable.Result + +import Transformations.ExtendedSyntax.Util +import Transformations.ExtendedSyntax.Names + + +-- TODO: make NameM local (it's only used once in ddeFromProducers) +-- (t,lv) -> t' +-- we deleted the dead fields from a node with tag t with liveness lv +-- then we introduced the new tag t' for this deleted node +type TagMapping = Map (Tag, Vector Bool) Tag +type Trf = ExceptT String (StateT TagMapping NameM) + +execTrf :: Exp -> Trf a -> Either String (a, ExpChanges) +execTrf e = moveChangedResult . evalNameM e . flip evalStateT mempty . runExceptT + where + moveChangedResult (x, b) = either Left (\r -> Right (r, b)) x + +getTag :: Tag -> Vector Bool -> Trf Tag +getTag t lv + | and lv = pure t +getTag t@(Tag ty n) lv = do + mt' <- gets $ Map.lookup (t,lv) + case mt' of + Just t' -> return t' + Nothing -> do + n' <- lift $ lift $ deriveNewName n + let t' = Tag ty n' + modify $ Map.insert (t,lv) t' + return t' + + +deadDataElimination :: LVAResult -> CByResult -> TypeEnv -> Exp -> Either String (Exp, ExpChanges) +deadDataElimination lvaResult cbyResult tyEnv e = execTrf e $ + ddeFromProducers lvaResult cbyResult tyEnv e >>= ddeFromConsumers cbyResult tyEnv + + +lookupNodeLivenessM :: Name -> Tag -> LVAResult -> Trf (Vector Bool) +lookupNodeLivenessM v t lvaResult = do + lvInfo <- lookupExcept (noLiveness v) v . _registerLv $ lvaResult + case lvInfo of + NodeSet taggedLiveness -> + _fields <$> lookupExcept (noLivenessTag v t) t taggedLiveness + _ -> throwE $ notANode v + where noLiveness v = noLivenessMsg ++ show (PP v) + noLivenessTag v t = noLivenessMsg ++ show (PP v) ++ " with tag " ++ show (PP t) + noLivenessMsg = "No liveness information present for variable " + notANode v = "Variable " ++ show (PP v) ++ " has non-node liveness information. " ++ + "Probable cause: Either lookupNodeLivenessM was called on a non-node variable, " ++ + "or the liveness information was never calculated for the variable " ++ + "(e.g.: it was inside a dead case alternative)." + +-- Global liveness is the accumulated liveness information about the producers +-- It represents the collective liveness of a producer group. +type GlobalLiveness = Map Name (Map Tag (Vector Bool)) + +{- + This should always get an active producer graph + Even if it does not, lookupNodeLivenessM will not be called on dead(1) variables, + because the connectProds set will be empty for such variables. + + (1) - Here "dead variable" means a variable that was not analyzed. + + NOTE: We will ignore undefined producers, since they should always be dead. +-} +calcGlobalLiveness :: LVAResult -> + ProducerGraph' -> + Trf GlobalLiveness +calcGlobalLiveness lvaResult (withoutUndefined -> prodGraph) = + mapWithDoubleKeyM' mergeLivenessExcept prodGraph where + + -- map using only the keys + mapWithDoubleKeyM' f = mapWithDoubleKeyM (\k1 k2 v -> f k1 k2) + + -- For producer p and tag t, it merges the liveness information of all fields + -- with the other producers sharing a consumer with p for tag t. + -- Every producer must have at least one connection for its own tag + -- with itself (reflexive closure). + -- NOTE: What if a ctor is applied to different number of arguments? + -- This can only happen at pattern matches, not at the time of construction. + -- So we do not have to worry about the liveness of those "extra" parameters. + -- They will always be at the last positions. + mergeLivenessExcept :: Name -> Tag -> Trf (Vector Bool) + mergeLivenessExcept prod tag = do + let ps = Set.toList connectedProds + when (null ps) (throwE $ noConnections prod tag) + ls <- mapM (\v -> lookupNodeLivenessM v tag lvaResult) ps + pure $ foldl1 (Vec.zipWith (||)) ls + + where + connectedProds :: Set Name + connectedProds = fromMaybe mempty + . Map.lookup tag + . fromMaybe mempty + . Map.lookup prod + $ prodGraph + + noConnections :: (Pretty a, Pretty b) => a -> b -> String + noConnections p t = "Producer " ++ show (PP p) ++ + " for tag " ++ show (PP t) ++ + " is not connected with any other producers" + +ddeFromConsumers :: CByResult -> TypeEnv -> (Exp, GlobalLiveness) -> Trf Exp +ddeFromConsumers cbyResult tyEnv (e, gblLiveness) = cataM alg e where + + alg :: ExpF Exp -> Trf Exp + alg = \case + ECaseF v alts -> do + alts' <- forM alts $ \case + Alt (NodePat t args) altName e -> do + (args',lv) <- deleteDeadFieldsM v t args + let deletedArgs = args \\ args' + e' <- bindToUndefineds tyEnv e deletedArgs + t' <- getTag t lv + pure $ Alt (NodePat t' args') altName e' + e -> pure e + pure $ ECase v alts' + + EBindF lhs (AsPat t args v) rhs -> do + (args',lv) <- deleteDeadFieldsM v t args + let deletedArgs = (args \\ args') + rhs' <- bindToUndefineds tyEnv rhs deletedArgs + t' <- getTag t lv + pure $ EBind lhs (AsPat t' args' v) rhs' + + e -> pure . embed $ e + + deleteDeadFieldsM :: Name -> Tag -> [a] -> Trf ([a], Vector Bool) + deleteDeadFieldsM v t args = do + gblLivenessVT <- lookupGlobalLivenessM v t + let args' = zipFilter args gblLivenessVT + liveness = Vec.fromList $ take (length args) gblLivenessVT + pure (args', liveness) + + -- Returns "all dead" if it cannot find the tag + -- This way it handles impossible case alternatives + -- NOTE: could also be solved by prior sparse case optimisation + lookupGlobalLivenessM :: Name -> Tag -> Trf [Bool] + lookupGlobalLivenessM v t = do + let pMap = _producerMap . _producers $ cbyResult + pSet <- _producerSet <$> lookupExcept (notFoundInPMap v) v pMap + flip catchE (const $ pure $ repeat False) $ do + ~(p:_) <- Set.toList <$> lookupExcept (notFoundInPSet t) t pSet + liveness <- lookupWithDoubleKeyExcept (notFoundLiveness p t) p t gblLiveness + pure $ Vec.toList liveness + +-- For each producer, it dummifies all locally unused fields. +-- If the field is dead for all other producers in the same group, +-- then it deletes that field. +-- Whenever it deletes a field, it makes a new entry into a table. +-- This table will be used to transform the consumers. +ddeFromProducers :: LVAResult -> CByResult -> TypeEnv -> Exp -> Trf (Exp, GlobalLiveness) +ddeFromProducers lvaResult cbyResult tyEnv e = (,) <$> cataM alg e <*> globalLivenessM where + + -- deleteing all globally unused fields + -- if the variable was not analyzed (has type T_Dead), it will be skipped + alg :: ExpF Exp -> Trf Exp + alg = \case + -- TODO: investigate as-pat case + e@(EBindF (SReturn (ConstTagNode t args)) bPat@(_bPatVar -> v) rhs) + | Just T_Dead <- tyEnv ^? variable . at v . _Just . _T_SimpleType + -> pure . embed $ e + -- TODO: investigate as-pat case + EBindF (SReturn (ConstTagNode t args)) bPat@(_bPatVar -> v) rhs -> do + globalLiveness <- globalLivenessM + nodeLiveness <- lookupNodeLivenessM v t lvaResult + globalNodeLiveness <- lookupWithDoubleKeyExcept (notFoundLiveness v t) v t globalLiveness + let onlyDummifiable = \locallyLive globallyLive -> not locallyLive && globallyLive + onlyDummifiables = Vec.zipWith onlyDummifiable nodeLiveness globalNodeLiveness + toBeDummified = zipFilter (zip [0..] args) (Vec.toList onlyDummifiables) + toBeDummifiedIxs = fst <$> toBeDummified + toBeDummifiedArgs = snd <$> toBeDummified + typedDummifiedArgs <- typedFreshNames toBeDummifiedArgs + newTag <- getTag t globalNodeLiveness -- could be the same as the old one + let argsVec = Vec.fromList args + indexedNewArgs = Vec.fromList . zip toBeDummifiedIxs . map fst $ typedDummifiedArgs + newArgs = Vec.toList $ Vec.update argsVec indexedNewArgs + liveNewArgs = zipFilter newArgs (Vec.toList globalNodeLiveness) + returnNewNode = SReturn (ConstTagNode newTag liveNewArgs) + pure $ typedDummifiedArgs `areBoundThen` EBind returnNewNode bPat rhs + e -> pure . embed $ e + + -- extracts the active producer grouping from the CByResult + -- if not present, it calculates it (so it will always work with only the active producers) + prodGraph :: ProducerGraph' + prodGraph = case _groupedProducers cbyResult of + All _ -> fromProducerGraph + . groupActiveProducers lvaResult + . _producers + $ cbyResult + Active activeProdGraph -> fromProducerGraph activeProdGraph + + globalLivenessM :: Trf GlobalLiveness + globalLivenessM = calcGlobalLiveness lvaResult prodGraph + + -- NOTE: uses tyEnv from outer scope + -- | Given a list of names, it looks up their types + -- and pairs them with fresh new names. + typedFreshNames :: [Name] -> Trf [(Name, Type)] + typedFreshNames ns = forM ns $ \v -> do + v' <- lift $ lift $ deriveNewName v + ty <- lookupExcept (notFoundInTyEnv v) v (_variable tyEnv) + let ty' = simplifyType ty + pure (v', ty') + + -- TODO: comment + -- | Constructs a binding sequence which first + -- binds the typed undefineds to the given names, + -- then returns a node with those arguments. + areBoundThen :: [(Name, Type)] -> Exp -> Exp + areBoundThen typedDummifiedArgs cont = + foldl rebindToUndefined cont typedDummifiedArgs where + + -- returnNewNode :: Exp + -- returnNewNode = SReturn (ConstTagNode tag allArgs) + + rebindToUndefined :: Exp -> (Name, Type) -> Exp + rebindToUndefined rhs (v, ty) = + EBind (SReturn (Undefined ty)) (VarPat v) rhs + +notFoundInPMap :: Pretty a => a -> String +notFoundInPMap v = notFoundIn "Variable" (PP v) "producer map" + +notFoundInPSet :: Pretty a => a -> String +notFoundInPSet t = notFoundIn "Tag" (PP t) "producer set" + +notFoundLiveness :: (Pretty a, Pretty b) => a -> b -> String +notFoundLiveness p t = "Producer " ++ show (PP p) ++ + " with tag " ++ show (PP t) ++ + " not found in global liveness map" + +notFoundInTyEnv :: Pretty a => a -> String +notFoundInTyEnv v = notFoundIn "Variable" (PP v) "type environment" + +notFoundInTySetFor :: (Pretty a, Pretty b) => a -> b -> String +notFoundInTySetFor t v = (notFoundIn "Tag" (PP t) "node type set") ++ " for variable " ++ show (PP v) diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/DeadParameterElimination.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/DeadParameterElimination.hs new file mode 100644 index 00000000..cc5591f8 --- /dev/null +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/DeadParameterElimination.hs @@ -0,0 +1,53 @@ +{-# LANGUAGE LambdaCase, RecordWildCards #-} +module Transformations.ExtendedSyntax.Optimising.DeadParameterElimination where + +import Data.Set (Set) +import Data.Map (Map) +import Data.Vector (Vector) + +import qualified Data.Set as Set +import qualified Data.Map as Map +import qualified Data.Vector as Vec + +import Data.List + +import qualified Data.Foldable +import Data.Functor.Foldable as Foldable + +import Control.Monad.Trans.Except + +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.TypeEnvDefs +import Transformations.ExtendedSyntax.Util +import AbstractInterpretation.ExtendedSyntax.LiveVariable.Result as LVA + +type Trf = Except String + +runTrf :: Trf a -> Either String a +runTrf = runExcept + +-- P and F nodes are handled by Dead Data Elimination +deadParameterElimination :: LVAResult -> TypeEnv -> Exp -> Either String Exp +deadParameterElimination lvaResult tyEnv = runTrf . cataM alg where + alg :: ExpF Exp -> Trf Exp + alg = \case + DefF f args body -> do + liveArgs <- onlyLiveArgs f args + let deletedArgs = args \\ liveArgs + body' <- bindToUndefineds tyEnv body deletedArgs + return $ Def f liveArgs body' + SAppF f args -> do + liveArgs <- onlyLiveArgs f args + return $ SApp f liveArgs + e -> pure . embed $ e + + onlyLiveArgs :: Name -> [a] -> Trf [a] + onlyLiveArgs f args = do + argsLv <- lookupArgLivenessM f lvaResult + return $ zipFilter args (Vec.toList argsLv) + +lookupArgLivenessM :: Name -> LVAResult -> Trf (Vector Bool) +lookupArgLivenessM f LVAResult{..} = do + let funNotFound = "Function " ++ show f ++ " was not found in liveness analysis result" + (_,argLv) <- lookupExcept funNotFound f _functionLv + return $ Vec.map isLive argLv diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/GeneralizedUnboxing.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/GeneralizedUnboxing.hs new file mode 100644 index 00000000..3690d893 --- /dev/null +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/GeneralizedUnboxing.hs @@ -0,0 +1,187 @@ +{-# LANGUAGE LambdaCase, TupleSections, OverloadedStrings #-} +module Transformations.ExtendedSyntax.Optimising.GeneralizedUnboxing where + +import Data.Set (Set) +import Data.Vector (Vector) +import Data.Map.Strict (Map) +import Data.Function (fix) +import Data.Bifunctor (second) +import Data.Functor.Infix ((<$$>)) +import Data.Functor.Foldable as Foldable +import Data.Maybe (catMaybes, mapMaybe, isJust) + +import Lens.Micro.Platform + +import qualified Data.Map.Strict as Map +import qualified Data.Set as Set +import qualified Data.Vector as Vector + +import Transformations.ExtendedSyntax.Util (anaM, apoM) +import Transformations.ExtendedSyntax.Names + +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.TypeEnv +import Grin.ExtendedSyntax.Pretty + + +generalizedUnboxing :: TypeEnv -> Exp -> (Exp, ExpChanges) +generalizedUnboxing te exp = if (null funs) + then (exp, NoChange) + else second + (const NewNames) -- New functions are created, but NameM monad is not used + (evalNameM exp (transformCalls funs te =<< transformReturns funs te exp)) + where + funs = functionsToUnbox te exp + +-- TODO: Support tagless nodes. + +tailCalls :: Exp -> Maybe [Name] +tailCalls = cata collect where + collect :: ExpF (Maybe [Name]) -> Maybe [Name] + collect = \case + DefF _ _ result -> result + EBindF _ _ result -> result + ECaseF _ alts -> nonEmpty $ concat $ catMaybes alts + AltF _ _ result -> result + SAppF f _ -> Just [f] + e -> Nothing + +nonEmpty :: [a] -> Maybe [a] +nonEmpty [] = Nothing +nonEmpty xs = Just xs + +doesReturnAKnownProduct :: TypeEnv -> Name -> Bool +doesReturnAKnownProduct = isJust <$$> returnsAUniqueTag + +returnsAUniqueTag :: TypeEnv -> Name -> Maybe (Tag, Type) +returnsAUniqueTag te name = do + (tag, vs) <- te ^? function . at name . _Just . _1 . _T_NodeSet . to Map.toList . to singleton . _Just + typ <- singleton (Vector.toList vs) + pure (tag, T_SimpleType typ) + +singleton :: [a] -> Maybe a +singleton = \case + [] -> Nothing + [a] -> Just a + _ -> Nothing + +transitive :: (Ord a) => (a -> Set a) -> Set a -> Set a +transitive f res0 = + let res1 = res0 `Set.union` (Set.unions $ map f $ Set.toList res0) + in if res1 == res0 + then res0 + else transitive f res1 + +-- TODO: Remove the fix combinator, explore the function +-- dependency graph and rewrite disqualify steps based on that. +functionsToUnbox :: TypeEnv -> Exp -> Set Name +functionsToUnbox te (Program exts defs) = result where + funName (Def n _ _) = n + + tailCallsMap :: Map Name [Name] + tailCallsMap = Map.fromList $ mapMaybe (\e -> (,) (funName e) <$> tailCalls e) defs + + tranisitiveTailCalls :: Map Name (Set Name) + tranisitiveTailCalls = Map.fromList $ map (\k -> (k, transitive inTailCalls (Set.singleton k))) $ Map.keys tailCallsMap + where + inTailCalls :: Name -> Set Name + inTailCalls n = maybe mempty Set.fromList $ Map.lookup n tailCallsMap + + nonCandidateTailCallMap = Map.withoutKeys tranisitiveTailCalls result0 + candidateCalledByNonCandidate = (Set.unions $ Map.elems nonCandidateTailCallMap) `Set.intersection` result0 + result = result0 `Set.difference` candidateCalledByNonCandidate + + result0 = Set.fromList $ step initial + initial = map funName $ filter (doesReturnAKnownProduct te . funName) defs + disqualify candidates = filter + (\candidate -> case Map.lookup candidate tailCallsMap of + Nothing -> True + Just calls -> all (`elem` candidates) calls) + candidates + step = fix $ \rec x0 -> + let x1 = disqualify x0 in + if x0 == x1 + then x0 + else rec x1 + +updateTypeEnv :: Set Name -> TypeEnv -> TypeEnv +updateTypeEnv funs te = te & function %~ unboxFun + where + unboxFun = Map.fromList . map changeFun . Map.toList + changeFun (n, ts@(ret, params)) = + if Set.member n funs + then (,) (n <> ".unboxed") + $ maybe ts ((\t -> (t, params)) . T_SimpleType) $ + ret ^? _T_NodeSet + . to Map.elems + . to singleton + . _Just + . to Vector.toList + . to singleton + . _Just + else (n, ts) + +transformReturns :: Set Name -> TypeEnv -> Exp -> NameM Exp +transformReturns toUnbox te exp = apoM builder (Nothing, exp) where + builder :: (Maybe (Tag, Type), Exp) -> NameM (ExpF (Either Exp (Maybe (Tag, Type), Exp))) + builder (mTagType, exp0) = case exp0 of + Def name params body + | Set.member name toUnbox -> pure $ DefF name params (Right (returnsAUniqueTag te name, body)) + | otherwise -> pure $ DefF name params (Left body) + + -- Always skip the lhs of a bind. + EBind lhs pat rhs -> pure $ EBindF (Left lhs) pat (Right (mTagType, rhs)) + + -- Remove the tag from the value + SReturn (ConstTagNode tag [arg]) -> pure $ SReturnF (Var arg) + + -- Rewrite a node variable + simpleExp + -- fromJust works, as when we enter the processing of body of the + -- expression only happens with the provided tag. + | canUnbox simpleExp + , Just (tag, typ) <- mTagType + -> do + freshName <- deriveNewName $ "unboxed." <> (showTS $ PP tag) + asPatName <- deriveWildCard + pure . SBlockF . Left $ EBind simpleExp (AsPat tag [freshName] asPatName) (SReturn $ Var freshName) + + rest -> pure (Right . (,) mTagType <$> project rest) + + -- NOTE: SApp is handled by transformCalls + canUnbox :: SimpleExp -> Bool + canUnbox = \case + SApp n ps -> n `Set.notMember` toUnbox + SReturn{} -> True + SFetch{} -> True + _ -> False + +transformCalls :: Set Name -> TypeEnv -> Exp -> NameM Exp +transformCalls toUnbox typeEnv exp = anaM builderM (True, Nothing, exp) where + builderM :: (Bool, Maybe Name, Exp) -> NameM (ExpF (Bool, Maybe Name, Exp)) + + builderM (isRightExp, mDefName, e) = case e of + + Def name params body + -> pure $ DefF (if Set.member name toUnbox then name <> ".unboxed" else name) params (True, Just name, body) + + -- track the control flow + EBind lhs pat rhs -> pure $ EBindF (False, mDefName, lhs) pat (isRightExp, mDefName, rhs) + + SApp name params + | Set.member name toUnbox + , Just defName <- mDefName + , unboxedName <- name <> ".unboxed" + , Just (tag, fstType) <- returnsAUniqueTag typeEnv name + -> if Set.member defName toUnbox && isRightExp + + -- from candidate to candidate: tailcalls do not need a transform + then pure $ SAppF unboxedName params + + -- from outside to candidate + else do + freshName <- deriveNewName $ "unboxed." <> (showTS $ PP tag) + pure . SBlockF . (isRightExp, mDefName,) $ + EBind (SApp unboxedName params) (VarPat freshName) (SReturn $ ConstTagNode tag [freshName]) + + rest -> pure ((isRightExp, mDefName,) <$> project rest) diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/Inlining.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/Inlining.hs new file mode 100644 index 00000000..eca6f86e --- /dev/null +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/Inlining.hs @@ -0,0 +1,122 @@ +{-# LANGUAGE LambdaCase, TupleSections, RecordWildCards, OverloadedStrings #-} +module Transformations.ExtendedSyntax.Optimising.Inlining where + + +import Data.Set (Set) +import Data.Map.Strict (Map) +import Data.Bifunctor (first) +import Data.Functor.Foldable as Foldable + +import qualified Data.Set as Set +import qualified Data.Map.Strict as Map +import qualified Data.Foldable + +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.TypeEnv +import Transformations.ExtendedSyntax.Util +import Transformations.ExtendedSyntax.Names + +-- analysis + +data Stat + = Stat + { bindCount :: !Int + , functionCallCount :: !(Map Name Int) + } + +instance Semigroup Stat where (Stat i1 m1) <> (Stat i2 m2) = Stat (i1 + i2) (Map.unionWith (+) m1 m2) +instance Monoid Stat where mempty = Stat 0 mempty + +selectInlineSet :: Program -> Set Name +selectInlineSet prog@(Program exts defs) = inlineSet where + + (bindList, callTrees) = unzip + [ (Map.singleton name bindCount, (name, functionCallCount)) + | def@(Def name _ _) <- defs + , let Stat{..} = cata folder def + ] + + bindSequenceLimit = 100 + + -- TODO: limit inline overhead using CALL COUNT * SIZE < LIMIT + + callSet = Map.keysSet . Map.filter (== 1) . Map.unionsWith (+) $ map snd callTrees + bindSet = Map.keysSet . Map.filter (< bindSequenceLimit) $ mconcat bindList + candidateSet = mconcat [bindSet `Set.intersection` leafSet, callSet] + defCallTree = Map.fromList callTrees + leafSet = Set.fromList [name | (name, callMap) <- callTrees, Map.null callMap] + + -- keep only the leaves of the candidate call tree + inlineSet = Set.delete "grinMain" $ Data.Foldable.foldr stripCallers candidateSet candidateSet + + -- remove intermediate nodes from the call tree + stripCallers name set = set Set.\\ (Map.keysSet $ Map.findWithDefault mempty name defCallTree) + + + folder :: ExpF Stat -> Stat + folder = \case + EBindF left _ right + -> mconcat [left, right, Stat 1 mempty] + + SAppF name _ + | not (isExternalName exts name) + -> Stat 0 $ Map.singleton name 1 + + exp -> Data.Foldable.fold exp + +-- transformation + +-- TODO: add the cloned variables to the type env +-- QUESTION: apo OR ana ??? +inlining :: Set Name -> TypeEnv -> Program -> (Program, ExpChanges) +inlining functionsToInline typeEnv prog@(Program exts defs) = evalNameM prog $ apoM builder prog where + + defMap :: Map Name Def + defMap = Map.fromList [(name, def) | def@(Def name _ _) <- defs] + + builder :: Exp -> NameM (ExpF (Either Exp Exp)) + builder = \case + + -- HINT: do not touch functions marked to inline + Def name args body | Set.member name functionsToInline -> pure . DefF name args $ Left body + + -- HINT: bind argument values to function's new arguments and append the body with the fresh names + -- with this solution the name refreshing is just a name mapping and does not require a substitution map + SApp name args + | Set.member name functionsToInline + , Just def <- Map.lookup name defMap + -> do + freshDef <- refreshNames mempty def + let (Def _ argNames funBody, nameMap) = freshDef + let bind (n,v) e = EBind (SReturn v) (VarPat n) e + pure . SBlockF . Left $ foldr bind funBody . zip argNames . map Var $ args + + exp -> pure (Right <$> project exp) + +{- + - maintain type env + - test inlining + - test inline selection + - test inline: autoselection + inlining + +-} + +lateInlining :: TypeEnv -> Exp -> (Exp, ExpChanges) +lateInlining typeEnv prog = first (cleanup nameSet typeEnv) $ inlining nameSet typeEnv prog where + nameSet = selectInlineSet prog + +inlineEval :: TypeEnv -> Exp -> (Exp, ExpChanges) +inlineEval te = first (cleanup nameSet te) . inlining nameSet te where + nameSet = Set.fromList ["eval", "idr_{EVAL_0}"] + +inlineApply :: TypeEnv -> Exp -> (Exp, ExpChanges) +inlineApply te = first (cleanup nameSet te) . inlining nameSet te where + nameSet = Set.fromList ["apply", "idr_{APPLY_0}"] + +inlineBuiltins :: TypeEnv -> Exp -> (Exp, ExpChanges) +inlineBuiltins te = first (cleanup nameSet te) . inlining nameSet te where + nameSet = Set.fromList ["_rts_int_gt", "_rts_int_add", "_rts_int_print"] -- TODO: use proper selection + +cleanup :: Set Name -> TypeEnv -> Program -> Program +cleanup nameSet typeEnv (Program exts defs) = + Program exts [def | def@(Def name _ _) <- defs, Set.notMember name nameSet] diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/NonSharedElimination.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/NonSharedElimination.hs new file mode 100644 index 00000000..51331e07 --- /dev/null +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/NonSharedElimination.hs @@ -0,0 +1,36 @@ +{-# LANGUAGE LambdaCase, RecordWildCards #-} +module Transformations.ExtendedSyntax.Optimising.NonSharedElimination where + +{- +Remove the updates that update only non-shared locations. +-} + +import qualified Data.Set as Set +import Data.Functor.Foldable as Foldable + +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.TypeEnv (ptrLocations) +import Grin.ExtendedSyntax.TypeCheck (typeEnvFromHPTResult) +import Transformations.ExtendedSyntax.Names (ExpChanges(..)) +import AbstractInterpretation.ExtendedSyntax.Sharing.Result (SharingResult(..)) + + + +nonSharedElimination :: SharingResult -> Exp -> (Exp, ExpChanges) +nonSharedElimination SharingResult{..} exp = (exp', change) where + + exp' = cata skipUpdate exp + + change = if exp' /= exp then DeletedHeapOperation else NoChange + + tyEnv = either error id $ typeEnvFromHPTResult _hptResult + + -- Remove bind when the parameter points to non-shared locations only. + skipUpdate :: ExpF Exp -> Exp + skipUpdate = \case + EBindF (SUpdate p _) _ rhs + | all notShared . ptrLocations tyEnv $ p -> rhs + exp -> embed exp + + notShared :: Loc -> Bool + notShared l = not $ Set.member l _sharedLocs diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/SimpleDeadFunctionElimination.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/SimpleDeadFunctionElimination.hs new file mode 100644 index 00000000..2dee7729 --- /dev/null +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/SimpleDeadFunctionElimination.hs @@ -0,0 +1,36 @@ +{-# LANGUAGE LambdaCase, TupleSections, OverloadedStrings #-} +module Transformations.ExtendedSyntax.Optimising.SimpleDeadFunctionElimination where + + +import Data.Map (Map) +import Data.Set (Set) +import Data.Functor.Foldable as Foldable + +import qualified Data.Map as Map +import qualified Data.Set as Set +import qualified Data.Foldable + +import Text.Printf + +import Grin.ExtendedSyntax.Grin + +simpleDeadFunctionElimination :: Program -> Program +simpleDeadFunctionElimination exp@(Program exts defs) = Program exts [def | def@(Def name _ _) <- defs, Set.member name liveDefs] where + defMap :: Map Name Def + defMap = Map.fromList [(name, def) | def@(Def name _ _) <- defs] + + lookupDef :: Name -> Maybe Def + lookupDef name = Map.lookup name defMap + + liveDefs :: Set Name + liveDefs = fst $ until (\(live, visited) -> live == visited) visit (Set.singleton "grinMain", mempty) + + visit :: (Set Name, Set Name) -> (Set Name, Set Name) + visit (live, visited) = (mappend live seen, mappend visited toVisit) where + toVisit = Set.difference live visited + seen = foldMap (maybe mempty (cata collect) . lookupDef) toVisit + + collect :: ExpF (Set Name) -> Set Name + collect = \case + SAppF name _ | not (isExternalName exts name) -> Set.singleton name + exp -> Data.Foldable.fold exp diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/SimpleDeadParameterElimination.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/SimpleDeadParameterElimination.hs new file mode 100644 index 00000000..90e1eb1b --- /dev/null +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/SimpleDeadParameterElimination.hs @@ -0,0 +1,77 @@ +{-# LANGUAGE LambdaCase, TupleSections #-} +module Transformations.ExtendedSyntax.Optimising.SimpleDeadParameterElimination where + +import Data.Set (Set) +import Data.Map (Map) +import Data.Maybe (mapMaybe) +import qualified Data.Set as Set +import qualified Data.Map as Map + +import Data.Functor.Foldable as Foldable +import qualified Data.Foldable + +import Grin.ExtendedSyntax.Grin +import Transformations.ExtendedSyntax.Util + +collectUsedNames :: Exp -> Set Name +collectUsedNames = cata folder where + folder exp = foldNameUseExpF Set.singleton exp `mappend` Data.Foldable.fold exp + +simpleDeadParameterElimination :: Program -> Program +simpleDeadParameterElimination prog@(Program exts defs) = ana builder prog where + deadArgMap :: Map Name (Set Int) + deadArgMap = mconcat $ mapMaybe deadArgsInDef defs + + deadArgsInDef :: Def -> Maybe (Map Name (Set Int)) + deadArgsInDef def@(Def name args _) + | usedNames <- collectUsedNames def + , deadArgIndices <- Set.fromList . map fst . filter (flip Set.notMember usedNames . snd) $ zip [0..] args + = if null deadArgIndices + then Nothing + else Just $ Map.singleton name deadArgIndices + + removeDead :: Set Int -> [a] -> [a] + removeDead dead args = [arg | (idx, arg) <- zip [0..] args, Set.notMember idx dead] + + builder :: Exp -> ExpF Exp + builder e = case mapValsExp pruneVal e of + Def name args body + | Just dead <- Map.lookup name deadArgMap + -> DefF name (removeDead dead args) body + + SApp name args + | Just dead <- Map.lookup name deadArgMap + -> SAppF name (removeDead dead args) + + EBind leftExp (AsPat tag args var) rightExp + | Tag kind tagName <- tag + , isPFtag kind + , Just deadIxs <- Map.lookup tagName deadArgMap + -> EBindF leftExp (AsPat tag (removeDead deadIxs args) var) rightExp + + Alt cpat@NodePat{} altName body + -> AltF (pruneCPat cpat) altName body + + exp -> project exp + + pruneVal :: Val -> Val + pruneVal = \case + ConstTagNode tag@(Tag kind name) args + | isPFtag kind + , Just dead <- Map.lookup name deadArgMap + -> ConstTagNode tag (removeDead dead args) + val -> val + + pruneCPat :: CPat -> CPat + pruneCPat = \case + NodePat tag@(Tag kind name) vars + | isPFtag kind + , Just deadIxs <- Map.lookup name deadArgMap + -> NodePat tag (removeDead deadIxs vars) + cpat -> cpat + + isPFtag :: TagType -> Bool + isPFtag = \case + F{} -> True + P{} -> True + _ -> False diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/SimpleDeadVariableElimination.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/SimpleDeadVariableElimination.hs new file mode 100644 index 00000000..93722b8b --- /dev/null +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/SimpleDeadVariableElimination.hs @@ -0,0 +1,52 @@ +{-# LANGUAGE LambdaCase #-} +module Transformations.ExtendedSyntax.Optimising.SimpleDeadVariableElimination where + +import Data.Set (Set) +import qualified Data.Set as Set +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Monoid + +import Data.Functor.Foldable as Foldable +import qualified Data.Foldable + +import Lens.Micro.Platform + +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.TypeEnv +import Grin.ExtendedSyntax.EffectMap +import Transformations.ExtendedSyntax.Util + + +-- TODO: consult EffectMap for side-effects +-- QUESTION: should SDVE use any interprocedural information? +simpleDeadVariableElimination :: EffectMap -> Exp -> Exp +simpleDeadVariableElimination effMap e = cata folder e ^. _1 where + + effectfulExternals :: Set Name + effectfulExternals = case e of + Program es _ -> Set.fromList $ map eName $ filter eEffectful es + _ -> Set.empty + + folder :: ExpF (Exp, Set Name, Bool) -> (Exp, Set Name, Bool) + folder = \case + + exp@(EBindF (left, _, True) bPat right) -> embedExp exp + exp@(EBindF (left, _, _) bPat right@(_, rightRef, _)) + | vars <- foldNames Set.singleton bPat -- if all the variables + , all (flip Set.notMember rightRef) vars -- are not referred + -> case left of + SBlock{} -> embedExp exp + _ -> right + + exp@(SAppF name _) -> + embedExp exp & _3 .~ (hasPossibleSideEffect name effMap || Set.member name effectfulExternals) + + exp -> embedExp exp + where + embedExp :: ExpF (Exp, Set Name, Bool) -> (Exp, Set Name, Bool) + embedExp exp0 = + ( embed (view _1 <$> exp0) + , foldNameUseExpF Set.singleton exp0 `mappend` Data.Foldable.fold (view _2 <$> exp0) + , getAny $ Data.Foldable.fold (view (_3 . to Any) <$> exp0) + ) diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/SparseCaseOptimisation.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/SparseCaseOptimisation.hs new file mode 100644 index 00000000..9e1c5022 --- /dev/null +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/SparseCaseOptimisation.hs @@ -0,0 +1,45 @@ +{-# LANGUAGE LambdaCase, TupleSections, RecordWildCards #-} +module Transformations.ExtendedSyntax.Optimising.SparseCaseOptimisation where + +import qualified Data.Map as Map +import Data.Set (Set) +import qualified Data.Set as Set +import Data.Functor.Foldable as Foldable + +import Control.Monad.Trans.Except + +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.Pretty +import Grin.ExtendedSyntax.TypeEnv +import Transformations.ExtendedSyntax.Util + +sparseCaseOptimisation :: TypeEnv -> Exp -> Either String Exp +sparseCaseOptimisation TypeEnv{..} = runExcept . anaM builder where + builder :: Exp -> Except String (ExpF Exp) + builder = \case + ECase scrut alts -> do + scrutType <- lookupExcept (notInTyEnv scrut) scrut _variable + let alts' = filterAlts scrutType alts + pure $ ECaseF scrut alts' + exp -> pure . project $ exp + + notInTyEnv v = "SCO: Variable " ++ show (PP v) ++ " not found in type env" + + filterAlts :: Type -> [Exp] -> [Exp] + filterAlts scrutTy alts = + [ alt + | alt@(Alt cpat _name _body) <- alts + , possible scrutTy allPatTags cpat + ] where allPatTags = Set.fromList [tag | Alt (NodePat tag _) _name _body <- alts] + + possible :: Type -> Set Tag -> CPat -> Bool + possible (T_NodeSet nodeSet) allPatTags cpat = case cpat of + NodePat tag _args -> Map.member tag nodeSet + -- HINT: the default case is redundant if normal cases fully cover the domain + DefaultPat -> not $ null (Set.difference (Map.keysSet nodeSet) allPatTags) + _ -> False + + possible ty@T_SimpleType{} _ cpat = case cpat of + LitPat lit -> ty == typeOfLit lit + DefaultPat -> True -- HINT: the value domain is unknown, it is not possible to prove if it overlaps or it is fully covered + _ -> False diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/TrivialCaseElimination.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/TrivialCaseElimination.hs index 166d8fd9..fbd07b6b 100644 --- a/grin/src/Transformations/ExtendedSyntax/Optimising/TrivialCaseElimination.hs +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/TrivialCaseElimination.hs @@ -10,5 +10,5 @@ trivialCaseElimination = ana builder where builder :: Exp -> ExpF Exp builder = \case ECase scrut [Alt DefaultPat altName body] -> SBlockF $ EBind (SReturn (Var scrut)) (VarPat altName) body - ECase scrut [Alt cpat altName body] -> SBlockF $ EBind (SReturn (Var scrut)) (cPatToAsPat altName cpat) body + ECase scrut [Alt cpat altName body] -> SBlockF $ EBind (SReturn (Var scrut)) (cPatToAsPat cpat altName) body exp -> project exp diff --git a/grin/src/Transformations/ExtendedSyntax/Util.hs b/grin/src/Transformations/ExtendedSyntax/Util.hs index 785280e7..6dc395f0 100644 --- a/grin/src/Transformations/ExtendedSyntax/Util.hs +++ b/grin/src/Transformations/ExtendedSyntax/Util.hs @@ -108,19 +108,19 @@ mapNameUseExp f = \case subst :: Ord a => Map a a -> a -> a subst env x = Map.findWithDefault x x env --- substitute all @Names@s in an @Exp@ +-- substitute all @Names@s in an @Exp@ (non-recursive) substVarRefExp :: Map Name Name -> Exp -> Exp substVarRefExp env = mapNameUseExp (subst env) --- substitute all @Names@s in a @Val@ +-- substitute all @Names@s in a @Val@ (non-recursive) substNamesVal :: Map Name Name -> Val -> Val substNamesVal env = mapNamesVal (subst env) --- specialized version of @subst@ to @Val@s +-- specialized version of @subst@ to @Val@s (non-recursive) substValsVal :: Map Val Val -> Val -> Val substValsVal env = subst env --- substitute all @Val@s in an @Exp@ +-- substitute all @Val@s in an @Exp@ (non-recursive) substVals :: Map Val Val -> Exp -> Exp substVals env = mapValsExp (subst env) @@ -130,9 +130,9 @@ cPatToVal = \case LitPat lit -> Lit lit DefaultPat -> Unit -cPatToAsPat :: Name -> CPat -> BPat -cPatToAsPat name (NodePat tag args) = AsPat tag args name -cPatToAsPat _ cPat = error $ "cPatToAsPat: cannot convert to as-pattern: " ++ show (PP cPat) +cPatToAsPat :: CPat -> Name -> BPat +cPatToAsPat (NodePat tag args) name = AsPat tag args name +cPatToAsPat cPat _ = error $ "cPatToAsPat: cannot convert to as-pattern: " ++ show (PP cPat) -- monadic recursion schemes -- see: https://jtobin.io/monadic-recursion-schemes diff --git a/grin/test-data/ExtendedSyntax/dead-data-elimination/length_after.grin b/grin/test-data/ExtendedSyntax/dead-data-elimination/length_after.grin index 4465d2d9..331dd565 100644 --- a/grin/test-data/ExtendedSyntax/dead-data-elimination/length_after.grin +++ b/grin/test-data/ExtendedSyntax/dead-data-elimination/length_after.grin @@ -1,54 +1,69 @@ -grinMain = n1 <- pure (CInt 1) - t1 <- store n1 - n2 <- pure (CInt 10000) - t2 <- store n2 - n3 <- pure (Fupto t1 t2) - t3 <- store n3 - n4 <- pure (Flength t3) - t4 <- store n4 - n5 <- eval t4 - (CInt r') <- pure n5 - _prim_int_print r' +grinMain = + k0 <- pure 1 + n1 <- pure (CInt k0) + t1 <- store n1 + k1 <- pure 10000 + n2 <- pure (CInt k1) + t2 <- store n2 + n3 <- pure (Fupto t1 t2) + t3 <- store n3 + n4 <- pure (Flength t3) + t4 <- store n4 + n5 <- eval t4 + (CInt r') @ _1 <- pure n5 + _prim_int_print r' -upto m n = n6 <- eval m - (CInt m') <- pure n6 - n7 <- eval n - (CInt n') <- pure n7 - b' <- _prim_int_gt m' n' - if b' then - n8 <- pure (CNil) - pure n8 - else - m1' <- _prim_int_add m' 1 - n9 <- pure (CInt m1') - m1 <- store n9 - n10 <- pure (Fupto m1 n) - p <- store n10 - n11 <- pure (CCons.0 p) - pure n11 +upto m n = + n6 <- eval m + (CInt m') @ _2 <- pure n6 + n7 <- eval n + (CInt n') @ _3 <- pure n7 + b' <- _prim_int_gt m' n' + case b' of + #True @ altT -> + n8 <- pure (CNil) + pure n8 + #False @ altF -> + k2 <- pure 1 + m1' <- _prim_int_add m' k2 + n9 <- pure (CInt m1') + m1 <- store n9 + n10 <- pure (Fupto m1 n) + p <- store n10 + n11 <- pure (CCons.0 p) + pure n11 -length l = l2 <- eval l - case l2 of - (CNil) -> - n12 <- pure (CInt 0) - pure n12 - (CCons.0 xs) -> - x <- pure (#undefined :: #ptr) - n13 <- length xs - (CInt l') <- pure n13 - len <- _prim_int_add l' 1 - n14 <- pure (CInt len) - pure n14 +length l = + l2 <- eval l + case l2 of + (CNil) @ alt1 -> + k3 <- pure 0 + n12 <- pure (CInt k3) + pure n12 + (CCons.0 xs) @ alt2 -> + x <- pure (#undefined :: #ptr) + n13 <- length xs + (CInt l') @ _4 <- pure n13 + k4 <- pure 1 + len <- _prim_int_add l' k4 + n14 <- pure (CInt len) + pure n14 -eval q = v <- fetch q - case v of - (CInt x'1) -> pure v - (CNil) -> pure v - (CCons.0 ys) -> y <- pure (#undefined :: #ptr) - pure v - (Fupto a b) -> w <- upto a b - update q w - pure w - (Flength c) -> z <- length c - update q z - pure z +eval q = + v <- fetch q + case v of + (CInt x'1) @ alt3 -> + pure v + (CNil) @ alt4 -> + pure v + (CCons.0 ys) @ alt5 -> + y <- pure (#undefined :: #ptr) + pure v + (Fupto a b) @ alt6 -> + w <- upto a b + _5 <- update q w + pure w + (Flength c) @ alt7 -> + z <- length c + _6 <- update q z + pure z diff --git a/grin/test-data/ExtendedSyntax/dead-data-elimination/length_before.grin b/grin/test-data/ExtendedSyntax/dead-data-elimination/length_before.grin index dc3eff64..87a95516 100644 --- a/grin/test-data/ExtendedSyntax/dead-data-elimination/length_before.grin +++ b/grin/test-data/ExtendedSyntax/dead-data-elimination/length_before.grin @@ -1,50 +1,67 @@ -grinMain = n1 <- pure (CInt 1) - t1 <- store n1 - n2 <- pure (CInt 10000) - t2 <- store n2 - n3 <- pure (Fupto t1 t2) - t3 <- store n3 - n4 <- pure (Flength t3) - t4 <- store n4 - n5 <- eval t4 - (CInt r') <- pure n5 - _prim_int_print r' +grinMain = + k0 <- pure 1 + n1 <- pure (CInt k0) + t1 <- store n1 + k1 <- pure 10000 + n2 <- pure (CInt k1) + t2 <- store n2 + n3 <- pure (Fupto t1 t2) + t3 <- store n3 + n4 <- pure (Flength t3) + t4 <- store n4 + n5 <- eval t4 + (CInt r') @ _1 <- pure n5 + _prim_int_print r' -upto m n = n6 <- eval m - (CInt m') <- pure n6 - n7 <- eval n - (CInt n') <- pure n7 - b' <- _prim_int_gt m' n' - if b' then - n8 <- pure (CNil) - pure n8 - else - m1' <- _prim_int_add m' 1 - n9 <- pure (CInt m1') - m1 <- store n9 - n10 <- pure (Fupto m1 n) - p <- store n10 - n11 <- pure (CCons m p) - pure n11 +upto m n = + n6 <- eval m + (CInt m') @ _2 <- pure n6 + n7 <- eval n + (CInt n') @ _3 <- pure n7 + b' <- _prim_int_gt m' n' + case b' of + #True @ altT -> + n8 <- pure (CNil) + pure n8 + #False @ altF -> + k2 <- pure 1 + m1' <- _prim_int_add m' k2 + n9 <- pure (CInt m1') + m1 <- store n9 + n10 <- pure (Fupto m1 n) + p <- store n10 + n11 <- pure (CCons m p) + pure n11 -length l = l2 <- eval l - case l2 of - (CNil) -> n12 <- pure (CInt 0) - pure n12 - (CCons x xs) -> n13 <- length xs - (CInt l') <- pure n13 - len <- _prim_int_add l' 1 - n14 <- pure (CInt len) - pure n14 +length l = + l2 <- eval l + case l2 of + (CNil) @ alt1 -> + k3 <- pure 0 + n12 <- pure (CInt k3) + pure n12 + (CCons x xs) @ alt2 -> + n13 <- length xs + (CInt l') @ _4<- pure n13 + k4 <- pure 1 + len <- _prim_int_add l' k4 + n14 <- pure (CInt len) + pure n14 -eval q = v <- fetch q - case v of - (CInt x'1) -> pure v - (CNil) -> pure v - (CCons y ys) -> pure v - (Fupto a b) -> w <- upto a b - update q w - pure w - (Flength c) -> z <- length c - update q z - pure z \ No newline at end of file +eval q = + v <- fetch q + case v of + (CInt x'1) @ alt3 -> + pure v + (CNil) @ alt4 -> + pure v + (CCons y ys) @ alt5 -> + pure v + (Fupto a b) @ alt6 -> + w <- upto a b + _5 <- update q w + pure w + (Flength c) @ alt7 -> + z <- length c + _6 <- update q z + pure z diff --git a/grin/test-data/ExtendedSyntax/dead-data-elimination/pnode_after.grin b/grin/test-data/ExtendedSyntax/dead-data-elimination/pnode_after.grin index 9cd3df82..74468f0d 100644 --- a/grin/test-data/ExtendedSyntax/dead-data-elimination/pnode_after.grin +++ b/grin/test-data/ExtendedSyntax/dead-data-elimination/pnode_after.grin @@ -1,7 +1,8 @@ grinMain = - a0 <- pure (CInt 5) - a1 <- pure (CInt 5) - a2 <- pure (CInt 5) + k0 <- pure 0 + a0 <- pure (CInt k0) + a1 <- pure (CInt k0) + a2 <- pure (CInt k0) p0 <- store a0 p1 <- store a1 p2 <- store a2 @@ -33,13 +34,13 @@ foo x0 y0 z0 = -- apply always gets the function node in whnf apply pf cur = case pf of - (P3foo) -> + (P3foo) @ alt1 -> n0 <- pure (P2foo cur) pure n0 - (P2foo v0) -> + (P2foo v0) @ alt2 -> n1 <- pure (P1foo v0 cur) pure n1 - (P1foo v1 v2) -> + (P1foo v1 v2) @ alt3 -> n2 <- foo v1 v2 cur pure n2 @@ -50,26 +51,26 @@ ap f x = eval p = v <- fetch p case v of - (CInt n) -> pure v + (CInt n) @ alt4 -> pure v - (P3foo) -> pure v - (P2foo v3) -> pure v - (P1foo v4 v5) -> pure v + (P3foo) @ alt5 -> pure v + (P2foo v3) @ alt6 -> pure v + (P1foo v4 v5) @ alt7 -> pure v - (Ffoo.0) -> + (Ffoo.0) @ alt8 -> b2 <- pure (#undefined :: T_Dead) b1 <- pure (#undefined :: T_Dead) b0 <- pure (#undefined :: T_Dead) w0 <- foo b0 b1 b2 - update p w0 + _1 <- update p w0 pure w0 - (Fapply.0) -> + (Fapply.0) @ alt9 -> y <- pure (#undefined :: T_Dead) g <- pure (#undefined :: T_Dead) w1 <- apply g y - update p w1 + _2 <- update p w1 pure w1 - (Fap h z) -> + (Fap h z) @ alt10 -> w2 <- ap h z - update p w2 + _3 <- update p w2 pure w2 diff --git a/grin/test-data/ExtendedSyntax/dead-data-elimination/pnode_before.grin b/grin/test-data/ExtendedSyntax/dead-data-elimination/pnode_before.grin index dd7a21ea..24ae3bac 100644 --- a/grin/test-data/ExtendedSyntax/dead-data-elimination/pnode_before.grin +++ b/grin/test-data/ExtendedSyntax/dead-data-elimination/pnode_before.grin @@ -13,9 +13,10 @@ -} grinMain = - a0 <- pure (CInt 5) - a1 <- pure (CInt 5) - a2 <- pure (CInt 5) + k0 <- pure 0 + a0 <- pure (CInt k0) + a1 <- pure (CInt k0) + a2 <- pure (CInt k0) p0 <- store a0 p1 <- store a1 p2 <- store a2 @@ -47,13 +48,13 @@ foo x0 y0 z0 = -- apply always gets the function node in whnf apply pf cur = case pf of - (P3foo) -> + (P3foo) @ alt1 -> n0 <- pure (P2foo cur) pure n0 - (P2foo v0) -> + (P2foo v0) @ alt2 -> n1 <- pure (P1foo v0 cur) pure n1 - (P1foo v1 v2) -> + (P1foo v1 v2) @ alt3 -> n2 <- foo v1 v2 cur pure n2 @@ -64,21 +65,21 @@ ap f x = eval p = v <- fetch p case v of - (CInt n) -> pure v + (CInt n) @ alt4 -> pure v - (P3foo) -> pure v - (P2foo v3) -> pure v - (P1foo v4 v5) -> pure v + (P3foo) @ alt5 -> pure v + (P2foo v3) @ alt6 -> pure v + (P1foo v4 v5) @ alt7 -> pure v - (Ffoo b0 b1 b2) -> + (Ffoo b0 b1 b2) @ alt8 -> w0 <- foo b0 b1 b2 - update p w0 + _1 <- update p w0 pure w0 - (Fapply g y) -> + (Fapply g y) @ alt9 -> w1 <- apply g y - update p w1 + _2 <- update p w1 pure w1 - (Fap h z) -> + (Fap h z) @ alt10 -> w2 <- ap h z - update p w2 + _3 <- update p w2 pure w2 diff --git a/grin/test/AbstractInterpretation/ExtendedSyntax/SharingSpec.hs b/grin/test/AbstractInterpretation/ExtendedSyntax/SharingSpec.hs index 7621b897..c90bb033 100644 --- a/grin/test/AbstractInterpretation/ExtendedSyntax/SharingSpec.hs +++ b/grin/test/AbstractInterpretation/ExtendedSyntax/SharingSpec.hs @@ -84,7 +84,7 @@ spec = describe "Sharing analysis" $ do l1 <- store two (CTwo l2)@_1 <- fetch l1 _2 <- fetch l2 - _2 <- fetch l2 + _3 <- fetch l2 pure () |] let result = calcSharedLocations code diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/ArityRaisingSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/ArityRaisingSpec.hs new file mode 100644 index 00000000..fba85c1f --- /dev/null +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/ArityRaisingSpec.hs @@ -0,0 +1,164 @@ +{-# LANGUAGE OverloadedStrings, QuasiQuotes #-} +module Transformations.ExtendedSyntax.Optimising.ArityRaisingSpec where + +import Transformations.ExtendedSyntax.Optimising.ArityRaising + +import Data.Monoid +import Control.Arrow +import qualified Data.Map.Strict as Map +import qualified Data.Vector as Vector + +import Test.Hspec + +import Test.ExtendedSyntax.Assertions +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.TH +import Grin.ExtendedSyntax.TypeEnv +import Grin.ExtendedSyntax.TypeCheck +import Transformations.ExtendedSyntax.Names (ExpChanges(..)) + + +runTests :: IO () +runTests = hspec spec + +spec :: Spec +spec = do + it "split_undefined" $ do + let tyEnv = inferTypeEnv testProgBefore + arityRaising 0 tyEnv testProgBefore `sameAs` (testProgAfter, NewNames) + +testProgBefore :: Exp +testProgBefore = [prog| +grinMain = + k0 <- pure 0 + v.0 <- pure (CInt k0) + p1 <- store v.0 + k1 <- pure 1 + v.1 <- pure (CInt k1) + p2 <- store v.1 + k2 <- pure 1000 + v.2 <- pure (CInt k2) + p3 <- store v.2 + v.3 <- pure (Fupto p2 p3) + p4 <- store v.3 + v.4 <- pure (Fsum p1 p4) + p5 <- store v.4 + v.5 <- fetch p5 + (Fsum p15 p16) @ _0 <- pure v.5 + n13' <- sum $ p15 p16 + _prim_int_print $ n13' + +sum p10 p11 = + v.6 <- fetch p11 + (Fupto p17 p18) @ _1 <- pure v.6 + v.7 <- fetch p17 + (CInt n2') @ _2 <- pure v.7 + v.8 <- fetch p18 + (CInt n3') @ _3 <- pure v.8 + b1' <- _prim_int_gt $ n2' n3' + case b1' of + #True @ alt1 -> + v.9 <- pure (CNil) + case v.9 of + (CNil) @ alt11 -> + v.10 <- fetch p10 + (CInt n14') @ _4 <- pure v.10 + pure n14' + (CCons.0) @ alt12 -> + ud0 <- pure (#undefined :: T_Dead) + ud1 <- pure (#undefined :: T_Dead) + sum $ ud0 ud1 + #False @ alt2 -> + k3 <- pure 1 + n4' <- _prim_int_add $ n2' k3 + v.14 <- pure (CInt n4') + p8 <- store v.14 + v.15 <- pure (Fupto p8 p18) + p9 <- store v.15 + v.16 <- pure (CCons p17 p9) + case v.16 of + (CNil) @ alt21 -> + pure (#undefined :: T_Dead) + (CCons p12_2 p13_2) @ alt22 -> + v.18 <- fetch p10 + (CInt n5'_2) @ _5 <- pure v.18 + v.19 <- fetch p12_2 + (CInt n6'_2) @ _6 <- pure v.19 + n7'_2 <- _prim_int_add $ n5'_2 n6'_2 + v.20 <- pure (CInt n7'_2) + p14_2 <- store v.20 + sum $ p14_2 p13_2 +|] + +testProgAfter :: Exp +testProgAfter = [prog| +grinMain = + k0 <- pure 0 + v.0 <- pure (CInt k0) + p1 <- store v.0 + k1 <- pure 1 + v.1 <- pure (CInt k1) + p2 <- store v.1 + k2 <- pure 1000 + v.2 <- pure (CInt k2) + p3 <- store v.2 + v.3 <- pure (Fupto p2 p3) + p4 <- store v.3 + v.4 <- pure (Fsum p1 p4) + p5 <- store v.4 + v.5 <- fetch p5 + (Fsum p15 p16) @ _0 <- pure v.5 + n13' <- do + (CInt p15.0.0.arity.1) @ _7 <- fetch p15 + (Fupto p16.0.0.arity.1 p16.0.0.arity.2) @ _8 <- fetch p16 + sum $ p15.0.0.arity.1 p16.0.0.arity.1 p16.0.0.arity.2 + _prim_int_print $ n13' + +sum p10.0.arity.1 p11.0.arity.1 p11.0.arity.2 = + v.6 <- pure (Fupto p11.0.arity.1 p11.0.arity.2) + (Fupto p17 p18) @ _1 <- pure v.6 + v.7 <- fetch p17 + (CInt n2') @ _2 <- pure v.7 + v.8 <- fetch p18 + (CInt n3') @ _3 <- pure v.8 + b1' <- _prim_int_gt $ n2' n3' + case b1' of + #True @ alt1 -> + v.9 <- pure (CNil) + case v.9 of + (CNil) @ alt11 -> + v.10 <- pure (CInt p10.0.arity.1) + (CInt n14') @ _4 <- pure v.10 + pure n14' + (CCons.0) @ alt12 -> + ud0 <- pure (#undefined :: T_Dead) + ud1 <- pure (#undefined :: T_Dead) + do + (CInt ud0.0.1.arity.1) @ _9 <- fetch ud0 + (Fupto ud1.0.1.arity.1 ud1.0.1.arity.2) @ _10 <- fetch ud1 + sum $ ud0.0.1.arity.1 ud1.0.1.arity.1 ud1.0.1.arity.2 + + #False @ alt2 -> + k3 <- pure 1 + n4' <- _prim_int_add $ n2' k3 + v.14 <- pure (CInt n4') + p8 <- store v.14 + v.15 <- pure (Fupto p8 p18) + p9 <- store v.15 + v.16 <- pure (CCons p17 p9) + case v.16 of + (CNil) @ alt21 -> + pure (#undefined :: T_Dead) + (CCons p12_2 p13_2) @ alt22 -> + v.18 <- pure (CInt p10.0.arity.1) + (CInt n5'_2) @ _5 <- pure v.18 + v.19 <- fetch p12_2 + (CInt n6'_2) @ _6 <- pure v.19 + n7'_2 <- _prim_int_add $ n5'_2 n6'_2 + v.20 <- pure (CInt n7'_2) + p14_2 <- store v.20 + do + (CInt p14_2.0.2.arity.1) @ _11 <- fetch p14_2 + (Fupto p13_2.0.2.arity.1 p13_2.0.2.arity.2) @ _12 <- fetch p13_2 + sum $ p14_2.0.2.arity.1 p13_2.0.2.arity.1 p13_2.0.2.arity.2 +|] diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/CaseCopyPropagationSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/CaseCopyPropagationSpec.hs new file mode 100644 index 00000000..f5f53709 --- /dev/null +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/CaseCopyPropagationSpec.hs @@ -0,0 +1,335 @@ +{-# LANGUAGE OverloadedStrings, QuasiQuotes #-} +module Transformations.ExtendedSyntax.Optimising.CaseCopyPropagationSpec where + +import Transformations.ExtendedSyntax.Optimising.CaseCopyPropagation + +import Data.Monoid + +import Test.Hspec + +import Test.ExtendedSyntax.New.Test hiding (newVar) +import Test.ExtendedSyntax.Assertions +import Grin.ExtendedSyntax.TH +import Grin.ExtendedSyntax.TypeEnv +import Grin.ExtendedSyntax.Pretty +import Transformations.ExtendedSyntax.Names (ExpChanges(..)) + + +ctxs :: [TestExpContext] +ctxs = + [ emptyCtx + , lastBindR + , firstAlt + , middleAlt + , lastAlt + ] + + +spec :: Spec +spec = testExprContextIn ctxs $ \ctx -> do + + it "Example from Figure 4.26" $ do + let teBefore = create $ + (newVar "z'" int64_t) <> + (newVar "y'" int64_t) <> + (newVar "x'" int64_t) + let before = [expr| + n0 <- pure (CNone) + m0 <- store n0 + u <- case v of + (Ffoo a) @ alt0 -> + y' <- foo a + pure (CInt y') + (Fbar b) @ alt1 -> + z' <- bar b + pure (CInt z') + (CInt x') @ alt2 -> + pure (CInt x') + pure m0 + |] + + let teAfter = extend teBefore $ + newVar "v'" int64_t + let after = [expr| + n0 <- pure (CNone) + m0 <- store n0 + u <- do + ccp.0 <- case v of + (Ffoo a) @ alt0 -> + y' <- foo a + pure y' + (Fbar b) @ alt1 -> + z' <- bar b + pure z' + (CInt x') @ alt2 -> + pure x' + pure (CInt ccp.0) + pure m0 + |] + -- TODO: Inspect type env + (caseCopyPropagation (snd (ctx (teBefore, before)))) `sameAs` ((snd (ctx (teAfter, after))), NewNames) + --(snd (ctx (teBefore, before))) `sameAs` (snd (ctx (teAfter, after))) + + it "One node has no Int tagged value" $ do + let typeEnv = emptyTypeEnv + let teBefore = create $ + (newVar "z'" float_t) <> + (newVar "y'" int64_t) <> + (newVar "x'" int64_t) + let before = [expr| + n0 <- pure (CNone) + m0 <- store n0 + u <- do + case v of + (Ffoo a) @ alt0 -> + y' <- foo a + pure (CInt y') + (Fbar b) @ alt1 -> + z' <- bar b + pure (CFloat z') + (CInt x') @ alt2 -> + pure (CInt x') + pure m0 + |] + let after = [expr| + n0 <- pure (CNone) + m0 <- store n0 + u <- do + case v of + (Ffoo a) @ alt0 -> + y' <- foo a + pure (CInt y') + (Fbar b) @ alt1 -> + z' <- bar b + pure (CFloat z') + (CInt x') @ alt2 -> + pure (CInt x') + pure m0 + |] + (caseCopyPropagation (snd (ctx (teBefore, before)))) `sameAs` ((snd (ctx (teBefore, after))), NoChange) + + it "Embedded good case" $ do + -- pendingWith "doesn't unbox outer case" + let teBefore = create $ + (newVar "z'" int64_t) <> + (newVar "y'" int64_t) <> + (newVar "x'" int64_t) <> + (newVar "z1'" int64_t) <> + (newVar "y1'" int64_t) <> + (newVar "x1'" int64_t) + let before = [expr| + n0 <- pure (CNone) + m0 <- store n0 + u <- case v of + (Ffoo a) @ alt0 -> + y' <- foo a + pure (CInt y') + (Fbar b) @ alt1 -> + z' <- bar b + pure (CInt z') + (CInt x') @ alt2 -> + u1 <- case v1 of + (Ffoo a1) @ alt20 -> + y1' <- foo a1 + pure (CInt y1') + (Fbar b1) @ alt21 -> + z1' <- bar b1 + pure (CInt z1') + (CInt x1') @ alt22 -> + pure (CInt x1') + pure (CInt x') + pure m0 + |] + let teAfter = extend teBefore $ + newVar "v'" int64_t <> + newVar "v1'" int64_t + let after = [expr| + n0 <- pure (CNone) + m0 <- store n0 + u <- do + ccp.1 <- case v of + (Ffoo a) @ alt0 -> + y' <- foo a + pure y' + (Fbar b) @ alt1 -> + z' <- bar b + pure z' + (CInt x') @ alt2 -> + u1 <- do + ccp.0 <- case v1 of + (Ffoo a1) @ alt20 -> + y1' <- foo a1 + pure y1' + (Fbar b1) @ alt21 -> + z1' <- bar b1 + pure z1' + (CInt x1') @ alt22 -> + pure x1' + pure (CInt ccp.0) + pure x' + pure (CInt ccp.1) + pure m0 + |] + (caseCopyPropagation (snd (ctx (teBefore, before)))) `sameAs` ((snd (ctx (teAfter, after))), NewNames) + + it "Embedded bad case" $ do + let teBefore = create $ + newVar "z'" int64_t <> + newVar "y'" int64_t <> + newVar "x'" int64_t <> + newVar "y1'" int64_t <> + newVar "z1'" float_t <> + newVar "x1'" int64_t + let before = [expr| + n0 <- pure (CNone) + m0 <- store n0 + u <- case v of + (Ffoo a) @ alt0 -> + y' <- foo a + pure (CInt y') + (Fbar b) @ alt1 -> + z' <- bar b + pure (CInt z') + (CInt x') @ alt2 -> + u1 <- do + case v1 of + (Ffoo a1) @ alt20 -> + y1' <- foo a1 + pure (CInt y1') + (Fbar b1) @ alt21 -> + z1' <- bar b1 + pure (CFloat z1') + (CInt x1') @ alt22 -> + pure (CInt x1') + pure (CInt x') + pure m0 + |] + let teAfter = extend teBefore $ + newVar "v'" int64_t + let after = [expr| + n0 <- pure (CNone) + m0 <- store n0 + u <- do + ccp.0 <- case v of + (Ffoo a) @ alt0 -> + y' <- foo a + pure y' + (Fbar b) @ alt1 -> + z' <- bar b + pure z' + (CInt x') @ alt2 -> + u1 <- do + case v1 of + (Ffoo a1) @ alt20 -> + y1' <- foo a1 + pure (CInt y1') + (Fbar b1) @ alt21 -> + z1' <- bar b1 + pure (CFloat z1') + (CInt x1') @ alt22 -> + pure (CInt x1') + pure x' + pure (CInt ccp.0) + pure m0 + |] + (caseCopyPropagation (snd (ctx (teBefore, before)))) `sameAs` ((snd (ctx (teAfter, after))), NewNames) + + it "Leave the outer, transform the inner" $ do + let teBefore = create $ + newVar "z'" float_t <> + newVar "y'" int64_t <> + newVar "x'" int64_t <> + newVar "y1'" int64_t <> + newVar "z1'" int64_t <> + newVar "x1'" int64_t + let before = [expr| + n0 <- pure (CNone) + m0 <- store n0 + u <- do + case v of + (Ffoo a) @ alt0 -> + y' <- foo a + pure (CInt y') + (Fbar b) @ alt1 -> + z' <- bar b + pure (CFloat z') + (CInt x') @ alt2 -> + u1 <- case v1 of + (Ffoo a1) @ alt20 -> + y1' <- foo a1 + pure (CInt y1') + (Fbar b1) @ alt21 -> + z1' <- bar b1 + pure (CInt z1') + (CInt x1') @ alt22 -> + pure (CInt x1') + pure (CInt x') + pure m0 + |] + let teAfter = extend teBefore $ + newVar "v1'" int64_t + let after = [expr| + n0 <- pure (CNone) + m0 <- store n0 + u <- do + case v of + (Ffoo a) @ alt0 -> + y' <- foo a + pure (CInt y') + (Fbar b) @ alt1 -> + z' <- bar b + pure (CFloat z') + (CInt x') @ alt2 -> + u1 <- do + ccp.0 <- case v1 of + (Ffoo a1) @ alt20 -> + y1' <- foo a1 + pure y1' + (Fbar b1) @ alt21 -> + z1' <- bar b1 + pure z1' + (CInt x1') @ alt22 -> + pure x1' + pure (CInt ccp.0) + pure (CInt x') + pure m0 + |] + + (caseCopyPropagation (snd (ctx (teBefore, before)))) `sameAs` ((snd (ctx (teAfter, after))), NewNames) + + it "last expression is a case" $ do + let teBefore = create $ + newVar "ax'" int64_t + let before = + [expr| + l2 <- eval l + case l2 of + (CNil) @ alt0 -> + k0 <- pure 0 + pure (CInt k0) + (CCons x xs) @ alt1 -> + (CInt x') @ v0 <- eval x + (CInt s') @ v1 <- sum xs + ax' <- _prim_int_add x' s' + pure (CInt ax') + |] + let teAfter = extend teBefore $ + newVar "l2'" int64_t + let after = + [expr| + l2 <- eval l + do ccp.0 <- case l2 of + (CNil) @ alt0 -> + k0 <- pure 0 + pure k0 + (CCons x xs) @ alt1 -> + (CInt x') @ v0 <- eval x + (CInt s') @ v1 <- sum xs + ax' <- _prim_int_add x' s' + pure ax' + pure (CInt ccp.0) + |] + (caseCopyPropagation (snd (ctx (teBefore, before)))) `sameAs` ((snd (ctx (teAfter, after))), NewNames) + +runTests :: IO () +runTests = hspec spec diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/CaseHoistingSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/CaseHoistingSpec.hs new file mode 100644 index 00000000..f161b63f --- /dev/null +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/CaseHoistingSpec.hs @@ -0,0 +1,203 @@ +{-# LANGUAGE OverloadedStrings, QuasiQuotes, ViewPatterns #-} +module Transformations.ExtendedSyntax.Optimising.CaseHoistingSpec where + +import Transformations.ExtendedSyntax.Optimising.CaseHoisting + +import Test.Hspec + +import Grin.ExtendedSyntax.TH +import Grin.ExtendedSyntax.TypeCheck +import Test.ExtendedSyntax.Assertions +import Transformations.ExtendedSyntax.Names (ExpChanges(..)) + + +runTests :: IO () +runTests = hspec spec + +spec :: Spec +spec = do + it "last case" $ do + let before = [prog| + grinMain = + v <- pure (CNil) + u <- case v of + (CNil) @ alt1 -> pure (CNil) + (CCons a1 b1) @ alt2 -> pure (CCons a1 b1) + case u of + (CNil) @ alt3 -> pure alt3 + (CCons a2 b2) @ alt4 -> pure (CNil) + |] + let after = [prog| + grinMain = + v <- pure (CNil) + case v of + (CNil) @ alt1 -> + u.0 <- do + pure (CNil) + alt3.0 <- pure u.0 + pure alt3.0 + (CCons a1 b1) @ alt2 -> + u.1 <- do + pure (CCons a1 b1) + alt4.0 <- pure u.1 + pure (CNil) + |] + caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames) + + it "middle case" $ do + let before = [prog| + grinMain = + v <- pure (CNil) + u <- case v of + (CNil) @ alt1 -> pure (CNil) + (CCons a1 b1) @ alt2 -> pure (CCons a1 b1) + r <- case u of + (CNil) @ alt3 -> pure 1 + (CCons a2 b2) @ alt4 -> pure 2 + pure r + |] + let after = [prog| + grinMain = + v <- pure (CNil) + r <- case v of + (CNil) @ alt1 -> + u.0 <- do + pure (CNil) + alt3.0 <- pure u.0 + pure 1 + (CCons a1 b1) @ alt2 -> + u.1 <- do + pure (CCons a1 b1) + alt4.0 <- pure u.1 + pure 2 + pure r + |] + caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames) + + it "default pattern" $ do + let before = [prog| + grinMain = + v <- pure (CNil) + u <- case v of + (CNil) @ alt1 -> pure (CNil) + (CCons a1 b1) @ alt2 -> pure (CCons a1 b1) + r <- case u of + (CNil) @ alt3 -> pure (CNil) + #default @ alt4 -> pure alt4 + pure r + |] + let after = [prog| + grinMain = + v <- pure (CNil) + r <- case v of + (CNil) @ alt1 -> + u.0 <- do + pure (CNil) + alt3.0 <- pure u.0 + pure (CNil) + (CCons a1 b1) @ alt2 -> + u.1 <- do + pure (CCons a1 b1) + alt4.0 <- pure u.1 + pure alt4.0 + pure r + |] + caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames) + + it "case chain + no code duplication" $ do + let before = [prog| + grinMain = + v <- pure 1 + u <- case v of + 0 @ alt1 -> pure (CNil) + 1 @ alt2 -> pure (CCons v v) + r <- case u of + (CNil) @ alt3 -> pure (CEmpty) + #default @ alt4 -> pure u + q <- case r of + (CVoid) @ alt5 -> + pure (CEmpty) + #default @ alt6 -> + k0 <- pure 777 + _1 <- _prim_int_print k0 + pure r + pure q + |] + let after = [prog| + grinMain = + v <- pure 1 + r <- case v of + 0 @ alt1 -> + u.0 <- do + pure (CNil) + alt3.0 <- pure u.0 + pure (CEmpty) + 1 @ alt2 -> + u.1 <- do + pure (CCons v v) + alt4.0 <- pure u.1 + pure u.1 + q <- case r of + (CVoid) @ alt5 -> + pure (CEmpty) + #default @ alt6 -> + k0 <- pure 777 + _1 <- _prim_int_print k0 + pure r + pure q + |] + caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames) + + it "default chain" $ do + let before = [prog| + grinMain = + v <- pure 1 + u <- case v of + 0 @ alt1 -> pure (CNil) + 1 @ alt2 -> pure (CCons v v) + r <- case u of + #default @ alt3 -> pure u + q <- case r of + #default @ alt4 -> pure r + pure q + |] + let after = [prog| + grinMain = + v <- pure 1 + u <- case v of + 0 @ alt1 -> + pure (CNil) + 1 @ alt2 -> + pure (CCons v v) + q <- case u of + #default @ alt3 -> + r.0 <- do + pure u + alt4.0 <- pure r.0 + pure r.0 + pure q + |] + caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames) + + it "ignore non linear variable" $ do + let before = [prog| + grinMain = + v <- pure (CNil) + u <- case v of + #default @ alt1 -> pure v + r <- case u of + #default @ alt2 -> pure u + x <- pure u + pure r + |] + let after = [prog| + grinMain = + v <- pure (CNil) + u <- case v of + #default @ alt1 -> pure v + r <- case u of + #default @ alt2 -> pure u + x <- pure u + pure r + |] + caseHoisting (inferTypeEnv before) before `sameAs` (after, NoChange) diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/ConstantPropagationSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/ConstantPropagationSpec.hs new file mode 100644 index 00000000..6f204eda --- /dev/null +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/ConstantPropagationSpec.hs @@ -0,0 +1,243 @@ +{-# LANGUAGE OverloadedStrings, QuasiQuotes, ViewPatterns #-} +module Transformations.ExtendedSyntax.Optimising.ConstantPropagationSpec where + +import Transformations.ExtendedSyntax.Optimising.ConstantPropagation + +import Test.Hspec + +import Grin.ExtendedSyntax.TH +import Test.ExtendedSyntax.Assertions + + +runTests :: IO () +runTests = hspec spec + + +spec :: Spec +spec = do + it "ignores binds" $ do + let before = [expr| + i1 <- pure 1 + i2 <- pure i1 + n1 <- pure (CNode i2) + n2 <- pure n1 + (CNode i3) @ n3 <- pure n1 + pure 2 + |] + let after = [expr| + i1 <- pure 1 + i2 <- pure i1 + n1 <- pure (CNode i2) + n2 <- pure n1 + (CNode i3) @ n3 <- pure n1 + pure 2 + |] + constantPropagation before `sameAs` after + + it "is not interprocedural" $ do + let before = [prog| + grinMain = + x <- f + case x of + (COne) @ alt1 -> pure 0 + (CTwo) @ alt2 -> pure 1 + + f = pure (COne) + |] + let after = [prog| + grinMain = + x <- f + case x of + (COne) @ alt1 -> pure 0 + (CTwo) @ alt2 -> pure 1 + + f = pure (COne) + |] + constantPropagation before `sameAs` after + + it "does not propagate info outwards of case expressions" $ do + let before = [prog| + grinMain = + x <- pure 0 + y <- case x of + 0 @ alt1 -> pure (COne) + case y of + (COne) @ alt2 -> pure 0 + (CTwo) @ alt3 -> pure 1 + |] + let after = [prog| + grinMain = + x <- pure 0 + y <- case x of + 0 @ alt1 -> pure (COne) + case y of + (COne) @ alt2 -> pure 0 + (CTwo) @ alt3 -> pure 1 + |] + constantPropagation before `sameAs` after + + it "base case" $ do + let before = [expr| + i1 <- pure 1 + n1 <- pure (CNode i1) + case n1 of + (CNil) @ alt1 -> pure 1 + (CNode a1) @ alt2 -> pure 2 + |] + let after = [expr| + i1 <- pure 1 + n1 <- pure (CNode i1) + do + (CNode a1) @ alt2 <- pure (CNode i1) + pure 2 + |] + constantPropagation before `sameAs` after + + it "ignores illformed case - multi matching" $ do + let before = [expr| + i1 <- pure 1 + n1 <- pure (CNode i1) + _1 <- case n1 of + (CNil) @ alt1 -> pure 1 + (CNode a1) @ alt2 -> pure 2 + (CNode b1) @ alt3 -> pure 3 + case n1 of + (CNil) @ alt4 -> pure 4 + #default @ alt5 -> pure 5 + #default @ alt6 -> pure 6 + |] + let after = [expr| + i1 <- pure 1 + n1 <- pure (CNode i1) + _1 <- case n1 of + (CNil) @ alt1 -> pure 1 + (CNode a1) @ alt2 -> pure 2 + (CNode b1) @ alt3 -> pure 3 + case n1 of + (CNil) @ alt4 -> pure 4 + #default @ alt5 -> pure 5 + #default @ alt6 -> pure 6 + |] + constantPropagation before `sameAs` after + + it "default pattern" $ do + let before = [expr| + i1 <- pure 1 + n1 <- pure (CNode i1) + case n1 of + (CNil) @ alt1 -> pure 2 + #default @ alt2 -> pure 3 + |] + let after = [expr| + i1 <- pure 1 + n1 <- pure (CNode i1) + do + alt2 <- pure n1 + pure 3 + |] + constantPropagation before `sameAs` after + + it "unknown scrutinee - simple" $ do + let before = [expr| + case n1 of + (CNil) @ alt1 -> pure 2 + #default @ alt2 -> pure 3 + |] + let after = [expr| + case n1 of + (CNil) @ alt1 -> pure 2 + #default @ alt2 -> pure 3 + |] + constantPropagation before `sameAs` after + + it "unknown scrutinee becomes known in alternatives - specific pattern" $ do + let before = [expr| + case n1 of + (CNil) @ alt11 -> + case n1 of + (CNil) @ alt21 -> pure 1 + (CNode a1) @ alt22 -> pure 2 + (CNode a2) @ alt12 -> + case n1 of + (CNil) @ alt23 -> pure 3 + (CNode a3) @ alt24 -> pure 4 + |] + let after = [expr| + case n1 of + (CNil) @ alt11 -> + do + (CNil) @ alt21 <- pure (CNil) + pure 1 + (CNode a2) @ alt12 -> + do + (CNode a3) @ alt24 <- pure (CNode a2) + pure 4 + |] + constantPropagation before `sameAs` after + + it "unknown scrutinee becomes known in alternatives - default pattern" $ do + let before = [expr| + case n1 of + #default @ alt11 -> + case n1 of + #default @ alt21 -> pure 1 + (CNode a1) @ alt22 -> pure 2 + (CNode a2) @ alt12 -> + case n1 of + #default @ alt23 -> pure 3 + (CNode a3) @ alt24 -> pure 4 + |] + let after = [expr| + case n1 of + #default @ alt11 -> + do + alt21 <- pure n1 + pure 1 + (CNode a2) @ alt12 -> + do + (CNode a3) @ alt24 <- pure (CNode a2) + pure 4 + |] + constantPropagation before `sameAs` after + + it "literal - specific pattern" $ do + let before = [expr| + i1 <- pure 1 + case i1 of + (CNil) @ alt1 -> pure 1 + (CNode a1) @ alt2 -> pure 2 + 1 @ alt3 -> pure 3 + 2 @ alt4 -> pure 4 + #default @ alt5 -> pure 5 + |] + let after = [expr| + i1 <- pure 1 + case i1 of + (CNil) @ alt1 -> pure 1 + (CNode a1) @ alt2 -> pure 2 + 1 @ alt3 -> pure 3 + 2 @ alt4 -> pure 4 + #default @ alt5 -> pure 5 + |] + constantPropagation before `sameAs` after + + it "literal - default pattern" $ do + let before = [expr| + i1 <- pure 3 + case i1 of + (CNil) @ alt1 -> pure 1 + (CNode a1) @ alt2 -> pure 2 + 1 @ alt3 -> pure 3 + 2 @ alt4 -> pure 4 + #default @ alt5 -> pure 5 + |] + let after = [expr| + i1 <- pure 3 + case i1 of + (CNil) @ alt1 -> pure 1 + (CNode a1) @ alt2 -> pure 2 + 1 @ alt3 -> pure 3 + 2 @ alt4 -> pure 4 + #default @ alt5 -> pure 5 + |] + constantPropagation before `sameAs` after diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/CopyPropagationSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/CopyPropagationSpec.hs index 40adc464..3e46cde1 100644 --- a/grin/test/Transformations/ExtendedSyntax/Optimising/CopyPropagationSpec.hs +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/CopyPropagationSpec.hs @@ -59,7 +59,37 @@ spec = do |] copyPropagation (ctx before) `sameAs` (ctx after) - it "node value - node pattern" $ do + it "node value - node pattern 1" $ do + let before = [expr| + a1 <- pure 1 + b1 <- pure 0 + (CNode a2 b2) @ n1 <- pure (CNode a1 b1) + foo n1 + |] + let after = [expr| + a1 <- pure 1 + b1 <- pure 0 + n1 <- pure (CNode a1 b1) + foo n1 + |] + copyPropagation (ctx before) `sameAs` (ctx after) + + it "node value - node pattern 2" $ do + let before = [expr| + a1 <- pure 1 + b1 <- pure 0 + (CNode a2 b2) @ n1 <- pure (CNode a1 b1) + foo a2 + |] + let after = [expr| + a1 <- pure 1 + b1 <- pure 0 + n1 <- pure (CNode a1 b1) + foo a1 + |] + copyPropagation (ctx before) `sameAs` (ctx after) + + it "node value - node pattern 3" $ do let before = [expr| a1 <- pure 1 b1 <- pure 0 diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/DeadDataEliminationSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/DeadDataEliminationSpec.hs new file mode 100644 index 00000000..b1b541d3 --- /dev/null +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/DeadDataEliminationSpec.hs @@ -0,0 +1,736 @@ +{-# LANGUAGE OverloadedStrings, QuasiQuotes #-} +module Transformations.ExtendedSyntax.Optimising.DeadDataEliminationSpec where + +import Transformations.ExtendedSyntax.Optimising.DeadDataElimination + +import qualified Data.Map as Map +import qualified Data.Set as Set + +import Test.Hspec + +import Test.ExtendedSyntax.Util +import Test.ExtendedSyntax.Assertions +import Grin.ExtendedSyntax.TH +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.TypeCheck (inferTypeEnv) +import Grin.ExtendedSyntax.PrimOpsPrelude (withPrimPrelude) +import AbstractInterpretation.ExtendedSyntax.CreatedBy.Result (CByResult(..), ProducerGraph) +import AbstractInterpretation.ExtendedSyntax.CreatedBy.Util (groupActiveProducers, groupAllProducers, toProducerGraph) +import AbstractInterpretation.ExtendedSyntax.CreatedBySpec (calcCByResult) +import AbstractInterpretation.ExtendedSyntax.LiveVariableSpec (calcLiveness) + + +runTests :: IO () +runTests = hspec spec + +dde :: Exp -> Exp +dde e = fst $ either error id $ + deadDataElimination (calcLiveness e) (calcCByResult e) (inferTypeEnv e) e + +spec :: Spec +spec = do + describe "Dead Data Elimination" $ do + it "Impossible alternative" $ do + let before = [prog| + grinMain = + a0 <- pure 5 + n0 <- pure (CInt a0) + r <- case n0 of + (CInt c0) @ alt1 -> pure 0 + (CBool c1) @ alt2 -> pure 0 + pure r + |] + + let after = [prog| + grinMain = + a0 <- pure 5 + n0 <- pure (CInt.0) + r <- case n0 of + (CInt.0) @ alt1 -> + c0 <- pure (#undefined :: T_Int64) + pure 0 + (CBool.0) @ alt2 -> + c1 <- pure (#undefined :: T_Dead) + pure 0 + pure r + |] + dde before `sameAs` after + + it "As-Pattern Simple 1" $ do + let before = [prog| + grinMain = + a0 <- pure 0 + n0 <- pure (CInt a0) + (CInt b1) @ n1 <- pure n0 + (CInt b2) @ n2 <- pure n0 + pure b2 + |] + + let after = [prog| + grinMain = + a0 <- pure 0 + n0 <- pure (CInt a0) + (CInt b1) @ n1 <- pure n0 + (CInt b2) @ n2 <- pure n0 + pure b2 + |] + dde before `sameAs` after + + it "As-Pattern Simple 2" $ do + let before = [prog| + grinMain = + a0 <- pure 0 + (CInt b0) @ n0 <- pure (CInt a0) + (CInt b1) @ n1 <- pure n0 + pure b0 + |] + + let after = [prog| + grinMain = + a0 <- pure 0 + (CInt b0) @ n0 <- pure (CInt a0) + (CInt b1) @ n1 <- pure n0 + pure b0 + |] + dde before `sameAs` after + + it "As-Pattern Deletable" $ do + let before = [prog| + grinMain = + a0 <- pure 0 + (CInt b0) @ n0 <- pure (CInt a0) + (CInt b1) @ n1 <- pure n0 + pure 0 + |] + + let after = [prog| + grinMain = + a0 <- pure 0 + (CInt.0) @ n0 <- pure (CInt.0) + b0 <- pure (#undefined :: T_Int64) + (CInt.0) @ n1 <- pure n0 + b1 <- pure (#undefined :: T_Int64) + pure 0 + |] + dde before `sameAs` after + + it "As-Pattern Fetch" $ do + let before = [prog| + grinMain = + a0 <- pure 0 + n0 <- pure (CInt a0) + p0 <- store n0 + (CInt a1) @ n1 <- fetch p0 + case n0 of + (CInt a2) @ alt1 -> pure a2 + |] + + let after = [prog| + grinMain = + a0 <- pure 0 + n0 <- pure (CInt a0) + p0 <- store n0 + (CInt a1) @ n1 <- fetch p0 + case n0 of + (CInt a2) @ alt1 -> pure a2 + |] + dde before `sameAs` after + + it "Case Consumers" $ do + let before = [prog| + grinMain = + a0 <- pure 0 + n0' <- pure (CInt a0) + case n0' of + (CInt a1) @ n0 -> + case n0 of + (CInt a2) @ alt1 -> pure a2 + |] + + let after = [prog| + grinMain = + a0 <- pure 0 + n0' <- pure (CInt a0) + case n0' of + (CInt a1) @ n0 -> + case n0 of + (CInt a2) @ alt1 -> pure a2 + |] + dde before `sameAs` after + + it "Case Consumers Dummifiable" $ do + let before = [prog| + grinMain = + a0 <- pure 0 + a1 <- pure 0 + a2 <- pure 0 + a3 <- pure 0 + a4 <- pure 0 + a5 <- pure 0 + + -- two producers + n0 <- pure (CThree a0 a1 a2) + n1 <- pure (CThree a3 a4 a5) + + -- n01 has producers: n0, n1 + s0 <- pure 0 + n01 <- case s0 of + 0 @ alt1 -> pure n0 + 1 @ alt2 -> pure n1 + + -- consumers + case n0 of + (CThree b0 b1 b2) @ _1 -> + case n01 of + (CThree c0 c1 c2) @ _2 -> + case n1 of + (CThree d0 d1 d2) @ _3 -> + pure (CLive b0 c1 d2) + |] + + let after = [prog| + grinMain = + a0 <- pure 0 + a1 <- pure 0 + a2 <- pure 0 + a3 <- pure 0 + a4 <- pure 0 + a5 <- pure 0 + + -- two producers + a2.0 <- pure (#undefined :: T_Int64) + n0 <- pure (CThree a0 a1 a2.0) + a3.0 <- pure (#undefined :: T_Int64) + n1 <- pure (CThree a3.0 a4 a5) + + -- n01 has producers: n0, n1 + s0 <- pure 0 + n01 <- case s0 of + 0 @ alt1 -> pure n0 + 1 @ alt2 -> pure n1 + + -- consumers + case n0 of + (CThree b0 b1 b2) @ _1 -> + case n01 of + (CThree c0 c1 c2) @ _2 -> + case n1 of + (CThree d0 d1 d2) @ _3 -> + pure (CLive b0 c1 d2) + |] + dde before `sameAs` after + + it "Case Consumers Deletable" $ do + let before = [prog| + grinMain = + a0 <- pure 0 + n0' <- pure (CInt a0) + case n0' of + (CInt a1) @ n0 -> + case n0 of + (CInt a2) @ alt1 -> pure 0 + |] + + let after = [prog| + grinMain = + a0 <- pure 0 + n0' <- pure (CInt.0) + case n0' of + (CInt.0) @ n0 -> + a1 <- pure (#undefined :: T_Int64) + case n0 of + (CInt.0) @ alt1 -> + a2 <- pure (#undefined :: T_Int64) + pure 0 + |] + dde before `sameAs` after + + it "Multiple fields" $ do + let before = withPrimPrelude [prog| + grinMain = + a0 <- pure 0 + a1 <- pure 0 + a2 <- pure 0 + a3 <- pure 0 + a4 <- pure 0 + a5 <- pure 0 + n0 <- pure (CThree a0 a1 a2 a3 a4 a5) + (CThree b0 b1 b2 b3 b4 b5) @ _1 <- pure n0 + r <- _prim_int_add b1 b4 + pure r + |] + + let after = withPrimPrelude [prog| + grinMain = + a0 <- pure 0 + a1 <- pure 0 + a2 <- pure 0 + a3 <- pure 0 + a4 <- pure 0 + a5 <- pure 0 + n0 <- pure (CThree.0 a1 a4) + (CThree.0 b1 b4) @ _1 <- pure n0 + b5 <- pure (#undefined :: T_Int64) + b3 <- pure (#undefined :: T_Int64) + b2 <- pure (#undefined :: T_Int64) + b0 <- pure (#undefined :: T_Int64) + r <- _prim_int_add b1 b4 + pure r + |] + dde before `sameAs` after + + it "Only dummify" $ do + let before = [prog| + grinMain = + a0 <- pure 0 + a1 <- pure 0 + a2 <- pure 0 + a3 <- pure 0 + a4 <- pure 0 + a5 <- pure 0 + + -- two producers + n0 <- pure (CThree a0 a1 a2) + n1 <- pure (CThree a3 a4 a5) + + -- n01 has producers: n0, n1 + s0 <- pure 0 + n01 <- case s0 of + 0 @ alt1 -> pure n0 + 1 @ alt2 -> pure n1 + + -- consumers + (CThree b0 b1 b2) @ _1 <- pure n0 + (CThree c0 c1 c2) @ _2 <- pure n01 + (CThree d0 d1 d2) @ _3 <- pure n1 + + r <- pure (CLive b0 c1 d2) + pure r + |] + + let after = [prog| + grinMain = + a0 <- pure 0 + a1 <- pure 0 + a2 <- pure 0 + a3 <- pure 0 + a4 <- pure 0 + a5 <- pure 0 + + -- two producers + a2.0 <- pure (#undefined :: T_Int64) + n0 <- pure (CThree a0 a1 a2.0) + a3.0 <- pure (#undefined :: T_Int64) + n1 <- pure (CThree a3.0 a4 a5) + + -- n01 has producers: n0, n1 + s0 <- pure 0 + n01 <- case s0 of + 0 @ alt1 -> pure n0 + 1 @ alt2 -> pure n1 + + -- consumers + (CThree b0 b1 b2) @ _1 <- pure n0 + (CThree c0 c1 c2) @ _2 <- pure n01 + (CThree d0 d1 d2) @ _3 <- pure n1 + + r <- pure (CLive b0 c1 d2) + pure r + |] + dde before `sameAs` after + + it "Deletable Single" $ do + let before = [prog| + grinMain = + a0 <- pure 0 + a1 <- pure 0 + a2 <- pure 0 + a3 <- pure 0 + a4 <- pure 0 + a5 <- pure 0 + + -- two producers + n0 <- pure (CThree a0 a1 a2) + n1 <- pure (CThree a3 a4 a5) + + -- n01 has producers: n0, n1 + s0 <- pure 0 + n01 <- case s0 of + 0 @ alt1 -> pure n0 + 1 @ alt2 -> pure n1 + + -- consumers + (CThree b0 b1 b2) @ _1 <- pure n0 + (CThree c0 c1 c2) @ _2 <- pure n01 + (CThree d0 d1 d2) @ _3 <- pure n1 + + r <- pure (CLive b0 d2) + pure r + |] + + let after = [prog| + grinMain = + a0 <- pure 0 + a1 <- pure 0 + a2 <- pure 0 + a3 <- pure 0 + a4 <- pure 0 + a5 <- pure 0 + + -- two producers + a2.0 <- pure (#undefined :: T_Int64) + n0 <- pure (CThree.0 a0 a2.0) + a3.0 <- pure (#undefined :: T_Int64) + n1 <- pure (CThree.0 a3.0 a5) + + -- n01 has producers: n0, n1 + s0 <- pure 0 + n01 <- case s0 of + 0 @ alt1 -> pure n0 + 1 @ alt2 -> pure n1 + + -- consumers + (CThree.0 b0 b2) @ _1 <- pure n0 + b1 <- pure (#undefined :: T_Int64) + (CThree.0 c0 c2) @ _2 <- pure n01 + c1 <- pure (#undefined :: T_Int64) + (CThree.0 d0 d2) @ _3 <- pure n1 + d1 <- pure (#undefined :: T_Int64) + + r <- pure (CLive b0 d2) + pure r + |] + dde before `sameAs` after + + it "Deletable Multi" $ do + let before = [prog| + grinMain = + a0 <- pure 0 + a1 <- pure 0 + a2 <- pure 0 + a3 <- pure 0 + a4 <- pure 0 + a5 <- pure 0 + + -- two producers + n0 <- pure (CThree a0 a1 a2) + n1 <- pure (CThree a3 a4 a5) + + -- n01 has producers: n0, n1 + s0 <- pure 0 + n01 <- case s0 of + 0 @ alt1 -> pure n0 + 1 @ alt2 -> pure n1 + + -- consumers + (CThree b0 b1 b2) @ _1 <- pure n0 + (CThree c0 c1 c2) @ _2 <- pure n01 + (CThree d0 d1 d2) @ _3 <- pure n1 + + r <- pure (CLive c1) + pure r + |] + + let after = [prog| + grinMain = + a0 <- pure 0 + a1 <- pure 0 + a2 <- pure 0 + a3 <- pure 0 + a4 <- pure 0 + a5 <- pure 0 + + -- two producers + n0 <- pure (CThree.0 a1) + n1 <- pure (CThree.0 a4) + + -- n01 has producers: n0, n1 + s0 <- pure 0 + n01 <- case s0 of + 0 @ alt1 -> pure n0 + 1 @ alt2 -> pure n1 + + -- consumers + (CThree.0 b1) @ _1 <- pure n0 + b2 <- pure (#undefined :: T_Int64) + b0 <- pure (#undefined :: T_Int64) + + (CThree.0 c1) @ _2 <- pure n01 + c2 <- pure (#undefined :: T_Int64) + c0 <- pure (#undefined :: T_Int64) + + (CThree.0 d1) @ _3 <- pure n1 + d2 <- pure (#undefined :: T_Int64) + d0 <- pure (#undefined :: T_Int64) + + r <- pure (CLive c1) + pure r + |] + dde before `sameAs` after + + it "Separate Producers" $ do + let before = [prog| + grinMain = + a0 <- pure 0 + a1 <- pure 0 + a2 <- pure 0 + + n0 <- pure (CThree a0 a1 a2) + n1 <- pure (CThree a0 a1 a2) + n2 <- pure (CThree a0 a1 a2) + n3 <- pure (CThree a0 a1 a2) + n4 <- pure (CThree a0 a1 a2) + n5 <- pure (CThree a0 a1 a2) + n6 <- pure (CThree a0 a1 a2) + n7 <- pure (CThree a0 a1 a2) + + -- consumers + (CThree b0 b1 b2) @ _1 <- pure n0 + (CThree c0 c1 c2) @ _2 <- pure n1 + (CThree d0 d1 d2) @ _3 <- pure n2 + (CThree e0 e1 e2) @ _4 <- pure n3 + (CThree f0 f1 f2) @ _5 <- pure n4 + (CThree g0 g1 g2) @ _6 <- pure n5 + (CThree h0 h1 h2) @ _7 <- pure n6 + (CThree i0 i1 i2) @ _8 <- pure n7 + + r <- pure (CLive b0 b1 b2 c0 d1 e2 f0 f1 g1 g2 h0 h2) + pure r + |] + + let after = [prog| + grinMain = + a0 <- pure 0 + a1 <- pure 0 + a2 <- pure 0 + + n0 <- pure (CThree a0 a1 a2) + n1 <- pure (CThree.6 a0) + n2 <- pure (CThree.5 a1) + n3 <- pure (CThree.4 a2) + n4 <- pure (CThree.3 a0 a1) + n5 <- pure (CThree.2 a1 a2) + n6 <- pure (CThree.1 a0 a2) + n7 <- pure (CThree.0) + + (CThree b0 b1 b2) @ _1 <- pure n0 + + (CThree.6 c0) @ _2 <- pure n1 + c2 <- pure (#undefined :: T_Int64) + c1 <- pure (#undefined :: T_Int64) + + (CThree.5 d1) @ _3 <- pure n2 + d2 <- pure (#undefined :: T_Int64) + d0 <- pure (#undefined :: T_Int64) + + (CThree.4 e2) @ _4 <- pure n3 + e1 <- pure (#undefined :: T_Int64) + e0 <- pure (#undefined :: T_Int64) + + (CThree.3 f0 f1) @ _5 <- pure n4 + f2 <- pure (#undefined :: T_Int64) + + (CThree.2 g1 g2) @ _6 <- pure n5 + g0 <- pure (#undefined :: T_Int64) + + (CThree.1 h0 h2) @ _7 <- pure n6 + h1 <- pure (#undefined :: T_Int64) + + (CThree.0) @ _8 <- pure n7 + i2 <- pure (#undefined :: T_Int64) + i1 <- pure (#undefined :: T_Int64) + i0 <- pure (#undefined :: T_Int64) + + r <- pure (CLive b0 b1 b2 c0 d1 e2 f0 f1 g1 g2 h0 h2) + pure r + |] + dde before `sameAs` after + + it "FNode" $ do + let before = [prog| + grinMain = + k0 <- pure 5 + x0 <- pure (CInt k0) + p0 <- store x0 + a0 <- pure (Ffoo p0 p0 p0) + p1 <- store a0 + a1 <- eval p1 + pure a1 + + -- functions cannot return pointers + foo x y z = + y' <- eval y + pure y' + + eval p = + v <- fetch p + case v of + (CInt n) @ alt1 -> + pure v + (Ffoo x1 y1 z1) @ alt2 -> + w <- foo x1 y1 z1 + _1 <- update p w + pure w + |] + + let after = [prog| + grinMain = + k0 <- pure 5 + x0 <- pure (CInt k0) + p0 <- store x0 + a0 <- pure (Ffoo.0 p0) + p1 <- store a0 + a1 <- eval p1 + pure a1 + + foo x y z = + y' <- eval y + pure y' + + eval p = + v <- fetch p + case v of + (CInt n) @ alt1 -> + pure v + (Ffoo.0 y1) @ alt2 -> + z1 <- pure (#undefined :: #ptr) + x1 <- pure (#undefined :: #ptr) + w <- foo x1 y1 z1 + _1 <- update p w + pure w + |] + dde before `sameAs` after + + it "PNode" $ do + before <- loadTestData "dead-data-elimination/pnode_before.grin" + after <- loadTestData "dead-data-elimination/pnode_after.grin" + dde before `sameAs` after + + it "PNode Opt" $ do + let before = [prog| + grinMain = + k0 <- pure 0 + a0 <- pure (CInt k0) + a1 <- pure (CInt k0) + a2 <- pure (CInt k0) + p0 <- store a0 + p1 <- store a1 + p2 <- store a2 + + foo3 <- pure (P3foo) + + (P3foo) @ _1 <- pure foo3 + foo2 <- pure (P2foo p0) + + (P2foo v0) @ _2 <- pure foo2 + foo1 <- pure (P1foo v0 p1) + + (P1foo v1 v2) @ _3 <- pure foo1 + fooRet <- foo v1 v2 p2 + + pure fooRet + + + foo x0 y0 z0 = + y0' <- fetch y0 + (CInt n) @ _4 <- pure y0' + pure y0' + |] + + let after = [prog| + {- NOTE: + P2foo is renamed to P2foo.1 because the name generation + takes the node name as base. So here foo will be the base, + for which a name was already generated: P1foo.0. + -} + grinMain = + k0 <- pure 0 + a0 <- pure (CInt.0) + a1 <- pure (CInt k0) + a2 <- pure (CInt.0) + p0 <- store a0 + p1 <- store a1 + p2 <- store a2 + + foo3 <- pure (P3foo) + + (P3foo) @ _1 <- pure foo3 + foo2 <- pure (P2foo.1) + + (P2foo.1) @ _2 <- pure foo2 + v0 <- pure (#undefined :: #ptr) + foo1 <- pure (P1foo.0 p1) + + (P1foo.0 v2) @ _3 <- pure foo1 + v1 <- pure (#undefined :: #ptr) + fooRet <- foo v1 v2 p2 + + pure fooRet + + + foo x0 y0 z0 = + y0' <- fetch y0 + (CInt n) @ _4 <- pure y0' + pure y0' + |] + dde before `sameAs` after + + it "Length" $ do + before <- loadTestData "dead-data-elimination/length_before.grin" + after <- loadTestData "dead-data-elimination/length_after.grin" + dde before `sameAs` after + + describe "Producer Grouping" $ do + let exp = [prog| + grinMain = + k0 <- pure 0 + n0 <- pure (CInt k0) + n1 <- pure (CBool k0) + n2 <- pure (CBool k0) + n3 <- pure (CBool k0) + s <- pure 5 + n01 <- case s of + 0 @ alt1 -> pure n0 + 1 @ alt2 -> pure n1 + n12 <- case s of + 0 @ alt3 -> pure n1 + 1 @ alt4 -> pure n2 + n23 <- case s of + 0 @ alt5 -> pure n2 + 1 @ alt6 -> pure n3 + z0 <- case n01 of + (CInt c0) @ alt7 -> pure 5 + (CBool c1) @ alt8 -> pure 5 + (CBool z1) @ _1 <- case n12 of + (CInt c2) @ alt9 -> pure 5 + (CBool c3) @ alt10 -> pure 5 + (CBool z2) @ _2 <- pure n23 + pure 5 + |] + + it "multi_prod_simple_all" $ do + let multiProdSimpleAllExpected = mkGraph + [ ("n0", [ (cInt, ["n0"]) ] ) + , ("n1", [ (cBool, ["n1", "n2", "n3"]) ]) + , ("n2", [ (cBool, ["n1", "n2", "n3"]) ]) + , ("n3", [ (cBool, ["n1", "n2", "n3"]) ]) + ] + found = groupAllProducers . _producers . calcCByResult $ exp + found `shouldBe` multiProdSimpleAllExpected + + it "multi_prod_simple_active" $ do + let multiProdSimpleActiveExpected = mkGraph + [ ("n0", [ (cInt, ["n0"]) ]) + , ("n1", [ (cBool, ["n1"]) ]) + , ("n2", [ (cBool, ["n2"]) ]) + , ("n3", [ (cBool, ["n3"]) ]) + ] + let found = groupActiveProducers <$> calcLiveness <*> (_producers . calcCByResult) $ exp + found `shouldBe` multiProdSimpleActiveExpected + +mkGraph :: [ (Name, [(Tag, [Name])]) ] -> ProducerGraph +mkGraph = toProducerGraph + . Map.map (Map.map Set.fromList) + . Map.map Map.fromList + . Map.fromList + diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/DeadParameterEliminationSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/DeadParameterEliminationSpec.hs new file mode 100644 index 00000000..a7cf8cc5 --- /dev/null +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/DeadParameterEliminationSpec.hs @@ -0,0 +1,217 @@ +{-# LANGUAGE OverloadedStrings, QuasiQuotes #-} +module Transformations.ExtendedSyntax.Optimising.DeadParameterEliminationSpec where + +import Transformations.ExtendedSyntax.Optimising.DeadParameterElimination (deadParameterElimination) + +import Data.Either + +import Test.Hspec + +import Test.ExtendedSyntax.Util (loadTestData) +import Test.ExtendedSyntax.Assertions +import Grin.ExtendedSyntax.TH +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.PrimOpsPrelude (withPrimPrelude) +import Grin.ExtendedSyntax.TypeCheck (inferTypeEnv) +import AbstractInterpretation.ExtendedSyntax.LiveVariableSpec (calcLiveness) + + +runTests :: IO () +runTests = hspec spec + +dpe :: Exp -> Exp +dpe e = either error id $ + deadParameterElimination (calcLiveness e) (inferTypeEnv e) e + +spec :: Spec +spec = do + describe "Dead Parameter Elimination" $ do + + it "Fnode" $ do + let before = [prog| + grinMain = + k0 <- pure 5 + x0 <- pure (CInt k0) + p0 <- store x0 + a0 <- pure (Ffoo p0 p0 p0) + p1 <- store a0 + a1 <- eval p1 + pure a1 + + -- functions cannot return pointers + foo x y z = + y' <- eval y + pure y' + + eval p = + v <- fetch p + case v of + (CInt n) @ alt1 -> pure v + (Ffoo x1 y1 z1) @ alt2 -> + w <- foo x1 y1 z1 + _1 <- update p w + pure w + |] + + let after = [prog| + grinMain = + k0 <- pure 5 + x0 <- pure (CInt k0) + p0 <- store x0 + a0 <- pure (Ffoo p0 p0 p0) + p1 <- store a0 + a1 <- eval p1 + pure a1 + + -- functions cannot return pointers + foo y = + z <- pure (#undefined :: #ptr) + x <- pure (#undefined :: #ptr) + y' <- eval y + pure y' + + eval p = + v <- fetch p + case v of + (CInt n) @ alt1 -> pure v + (Ffoo x1 y1 z1) @ alt2 -> + w <- foo y1 + _1 <- update p w + pure w + |] + dpe before `sameAs` after + + -- TODO: reenable + -- it "Pnode" $ pipeline + -- "dead-parameter-elimination/pnode_before.grin" + -- "dead-parameter-elimination/pnode_after.grin" + -- deadParameterEliminationPipeline + + + it "PNode" $ do + before <- loadTestData "dead-parameter-elimination/pnode_before.grin" + after <- loadTestData "dead-parameter-elimination/pnode_after.grin" + dpe before `sameAs` after + + it "Pnode opt" $ do + let before = [prog| + grinMain = + k0 <- pure 5 + a0 <- pure (CInt k0) + a1 <- pure (CInt k0) + a2 <- pure (CInt k0) + p0 <- store a0 + p1 <- store a1 + p2 <- store a2 + + foo3 <- pure (P3foo) + + (P3foo) @ _1 <- pure foo3 + foo2 <- pure (P2foo p0) + + (P2foo v0) @ _2 <- pure foo2 + foo1 <- pure (P1foo v0 p1) + + (P1foo v1 v2) @ _3 <- pure foo1 + fooRet <- foo v1 v2 p2 + pure fooRet + + foo x0 y0 z0 = + y0' <- fetch y0 + (CInt n) @ _4 <- y0' + pure y0' + |] + + let after = [prog| + grinMain = + k0 <- pure 5 + a0 <- pure (CInt k0) + a1 <- pure (CInt k0) + a2 <- pure (CInt k0) + p0 <- store a0 + p1 <- store a1 + p2 <- store a2 + + foo3 <- pure (P3foo) + + (P3foo) @ _1 <- pure foo3 + foo2 <- pure (P2foo p0) + + (P2foo v0) @ _2 <- pure foo2 + foo1 <- pure (P1foo v0 p1) + + (P1foo v1 v2) @ _3 <- pure foo1 + fooRet <- foo v2 + pure fooRet + + foo y0 = + z0 <- pure (#undefined :: #ptr) + x0 <- pure (#undefined :: #ptr) + y0' <- fetch y0 + (CInt n) @ _4 <- y0' + pure y0' + |] + dpe before `sameAs` after + + it "Simple" $ do + let before = [prog| + grinMain = + k0 <- pure 5 + g k0 + + f x y = pure x + + g z = + k1 <- pure 0 + f k1 z + |] + + let after = [prog| + grinMain = + k0 <- pure 5 + g + + f x = + y <- pure (#undefined :: T_Int64) + pure x + + g = + z <- pure (#undefined :: T_Int64) + k1 <- pure 0 + f k1 + |] + dpe before `sameAs` after + + it "Mutually recursive" $ do + let before = [prog| + grinMain = + k0 <- pure 0 + f k0 k0 + + f x y = + k1 <- pure 0 + g x k1 + + g v w = + k2 <- pure 0 + f k2 w + |] + + let after = [prog| + grinMain = + k0 <- pure 0 + f + + f = + y <- pure (#undefined :: T_Int64) + x <- pure (#undefined :: T_Int64) + k1 <- pure 0 + g + + g = + w <- pure (#undefined :: T_Int64) + v <- pure (#undefined :: T_Int64) + k2 <- pure 0 + f + |] + dpe before `sameAs` after diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/GeneralizedUnboxingSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/GeneralizedUnboxingSpec.hs new file mode 100644 index 00000000..3dd65029 --- /dev/null +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/GeneralizedUnboxingSpec.hs @@ -0,0 +1,398 @@ +{-# LANGUAGE OverloadedStrings, QuasiQuotes, ViewPatterns #-} +module Transformations.ExtendedSyntax.Optimising.GeneralizedUnboxingSpec where + +import Transformations.ExtendedSyntax.Optimising.GeneralizedUnboxing + + +import qualified Data.Set as Set +import qualified Data.Map.Strict as Map +import qualified Data.Vector as Vector + +import Test.Hspec + +import Test.ExtendedSyntax.Assertions +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.TH +import Grin.ExtendedSyntax.TypeEnv +import Transformations.ExtendedSyntax.Names (ExpChanges(..)) + + +runTests :: IO () +runTests = hspec spec + +spec :: Spec +spec = do + it "Figure 4.21 (extended)" $ do + let teBefore = emptyTypeEnv + { _function = Map.fromList + [ ("test", (int64_t, Vector.fromList [int64_t])) + , ("foo", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo2", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo2B", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo2C", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo3", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo4", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo5", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("bar", (int64_t, Vector.fromList [])) + ] + } + let before = [prog| + test n = + k0 <- pure 1 + prim_int_add n k0 + + foo a1 a2 a3 = + b1 <- prim_int_add a1 a2 + b2 <- prim_int_add b1 a3 + pure (CInt b2) + + foo2 a1 a2 a3 = + c1 <- prim_int_add a1 a2 + foo c1 c1 a3 + + foo2B a1 a2 a3 = + c1 <- prim_int_add a1 a2 + do + foo c1 c1 a3 + + foo2C a1 a2 a3 = + c1 <- prim_int_add a1 a2 + case c1 of + #default @ alt1 -> pure c1 + (CInt x1) @ alt2 -> foo c1 c1 a3 + + foo3 a1 a2 a3 = + c1 <- prim_int_add a1 a2 + -- In this case the vectorisation did not happen. + c2 <- foo c1 c1 a3 + pure c2 + + foo4 a1 = + v <- pure (CInt a1) + pure v + + foo5 a1 = + n0 <- pure (CInt a1) + p <- store n0 + fetch p + + bar = + k1 <- pure 1 + n1 <- test k1 + (CInt y') @ _0 <- foo a1 a2 a3 + test y' + |] + let teAfter = emptyTypeEnv + { _function = Map.fromList + [ ("test", (int64_t, Vector.fromList [int64_t])) + , ("foo.unboxed", (int64_t, Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo2.unboxed", (int64_t, Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo2B.unboxed", (int64_t, Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo2C.unboxed", (int64_t, Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo3.unboxed", (int64_t, Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo4.unboxed", (int64_t, Vector.fromList [int64_t, int64_t, int64_t])) + , ("foo5.unboxed", (int64_t, Vector.fromList [int64_t, int64_t, int64_t])) + , ("bar", (int64_t, Vector.fromList [])) + ] + , _variable = Map.fromList + [ ("unboxed.CInt.0", int64_t) + , ("unboxed.CInt.1", int64_t) + , ("unboxed.CInt.2", int64_t) + , ("unboxed.CInt.3", int64_t) + , ("unboxed.CInt.4", int64_t) + , ("unboxed.CInt.5", int64_t) + ] + } + let after = [prog| + test n = + k0 <- pure 1 + prim_int_add n k0 + + foo.unboxed a1 a2 a3 = + b1 <- prim_int_add a1 a2 + b2 <- prim_int_add b1 a3 + pure b2 + + foo2.unboxed a1 a2 a3 = + c1 <- prim_int_add a1 a2 + foo.unboxed c1 c1 a3 + + foo2B.unboxed a1 a2 a3 = + c1 <- prim_int_add a1 a2 + do + foo.unboxed c1 c1 a3 + + foo2C.unboxed a1 a2 a3 = + c1 <- prim_int_add a1 a2 + case c1 of + #default @ alt1 -> + do + (CInt unboxed.CInt.0) @ _1 <- pure c1 + pure unboxed.CInt.0 + (CInt x1) @ alt2 -> + foo.unboxed c1 c1 a3 + + foo3.unboxed a1 a2 a3 = + c1 <- prim_int_add a1 a2 + c2 <- do + unboxed.CInt.4 <- foo.unboxed c1 c1 a3 + pure (CInt unboxed.CInt.4) + do + (CInt unboxed.CInt.1) @ _2 <- pure c2 + pure unboxed.CInt.1 + + foo4.unboxed a1 = + v <- pure (CInt a1) + do + (CInt unboxed.CInt.2) @ _3 <- pure v + pure unboxed.CInt.2 + + foo5.unboxed a1 = + n0 <- pure (CInt a1) + p <- store n0 + do + (CInt unboxed.CInt.3) @ _4 <- fetch p + pure unboxed.CInt.3 + + bar = + k1 <- pure 1 + n1 <- test k1 + (CInt y') @ _0 <- do + unboxed.CInt.5 <- foo.unboxed a1 a2 a3 + pure (CInt unboxed.CInt.5) + test y' + |] + generalizedUnboxing teBefore before `sameAs` (after, NewNames) + + it "Return values are in cases" $ do + let teBefore = emptyTypeEnv + { _function = + fun_t "int_eq" + [ T_NodeSet $ cnode_t "Int" [T_Int64] + , T_NodeSet $ cnode_t "Int" [T_Int64] + ] + (T_NodeSet $ cnode_t "Int" [T_Int64]) + , _variable = Map.fromList + [ ("eq0", T_NodeSet $ cnode_t "Int" [T_Int64]) + , ("eq1", T_NodeSet $ cnode_t "Int" [T_Int64]) + , ("eq0_1", int64_t) + , ("eq1_1", int64_t) + , ("eq2", bool_t) + ] + } + let before = [prog| + int_eq eq0 eq1 = + (CInt eq0_1) @ alt1 <- fetch eq0 + (CInt eq1_1) @ alt2 <- fetch eq1 + eq2 <- _prim_int_eq eq0_1 eq1_1 + case eq2 of + #False @ alt3 -> + k0 <- pure 0 + pure (CInt k0) + #True @ alt4 -> + k1 <- pure 1 + pure (CInt k1) + |] + let teAfter = emptyTypeEnv + { _function = + fun_t "int_eq.unboxed" + [ T_NodeSet $ cnode_t "Int" [T_Int64] + , T_NodeSet $ cnode_t "Int" [T_Int64] + ] + int64_t + , _variable = Map.fromList + [ ("eq0", T_NodeSet $ cnode_t "Int" [T_Int64]) + , ("eq1", T_NodeSet $ cnode_t "Int" [T_Int64]) + , ("eq0_1", int64_t) + , ("eq1_1", int64_t) + , ("eq2", bool_t) + ] + } + let after = [prog| + int_eq.unboxed eq0 eq1 = + (CInt eq0_1) @ alt1 <- fetch eq0 + (CInt eq1_1) @ alt2 <- fetch eq1 + eq2 <- _prim_int_eq eq0_1 eq1_1 + case eq2 of + #False @ alt3 -> + k0 <- pure 0 + pure k0 + #True @ alt4 -> + k1 <- pure 1 + pure k1 + |] + generalizedUnboxing teBefore before `sameAs` (after, NewNames) + + it "Step 1 for Figure 4.21" $ do + let teBefore = emptyTypeEnv + { _function = Map.fromList + [ ("test", (int64_t, Vector.fromList [int64_t])) + , ("foo", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("bar", (int64_t, Vector.fromList [])) + ] + } + let before = [prog| + test n = + k0 <- pure 1 + prim_int_add n k0 + + foo a1 a2 a3 = + b1 <- prim_int_add a1 a2 + b2 <- prim_int_add b1 a3 + pure (CInt b2) + + bar = + k1 <- pure 1 + n <- test k1 + (CInt y') @ _1 <- foo a1 a2 a3 + test y' + |] + functionsToUnbox teBefore before `shouldBe` (Set.fromList ["foo"]) + + it "Tail calls and general unboxing" $ do + let teBefore = emptyTypeEnv + { _function = Map.fromList + [ ("inside1", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t, int64_t, int64_t])) + , ("outside3", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ,(Tag C "Nat", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t])) + , ("outside4", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t])) + , ("outside2", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t])) + , ("outside1", (T_NodeSet + (Map.fromList + [(Tag C "Int", Vector.fromList [T_Int64]) + ]) + , Vector.fromList [int64_t])) + ] + } + let before = [prog| + inside1 a1 a2 a3 = + b1 <- prim_int_add a1 a2 + b2 <- prim_int_add b1 a3 + pure (CInt b2) + + outside4 = + k0 <- pure () + k1 <- pure 1 + _1 <- pure k0 + outside3 k2 + + outside3 p1 = + case p1 of + 1 @ alt1 -> inside1 p1 p1 p1 -- :: CInt Int + 2 @ alt2 -> outside2 p1 -- :: CNat Int + + outside2 p1 = + k0 <- pure () + k1 <- pure 1 + _2 <- pure k0 + outside1 p1 + + outside1 p1 = + k2 <- pure 1 + y <- prim_int_add p1 k2 + x <- pure (CNat y) + pure x + |] + functionsToUnbox teBefore before `shouldBe` mempty + + it "Tail call function 1" $ do + let fun = [def| + fun x = + l <- store x + k0 <- pure 3 + tail k0 + |] + tailCalls fun `shouldBe` (Just ["tail"]) + + it "Tail call function 2" $ do + let fun = [def| + fun x = + l <- pure x + k0 <- pure 1 + case k0 of + 1 @ alt1 -> + k1 <- pure 1 + x <- prim_int_add k1 k1 + tail1 x + 2 @ alt2 -> + k2 <- pure 2 + x <- prim_int_add k2 k2 + tail2 x + |] + tailCalls fun `shouldBe` (Just ["tail1", "tail2"]) + + it "Partially tail call function 2" $ do + let fun = [def| + fun x = + l <- store x + k0 <- pure 1 + case k0 of + 1 @ alt1 -> + k1 <- pure 1 + x <- prim_int_add k1 k1 + y <- tail x + pure y + 2 @ alt2 -> + k2 <- pure 2 + x <- prim_int_add k2 k2 + tail x + |] + tailCalls fun `shouldBe` (Just ["tail"]) + + it "Non-tail call function 1" $ do + let fun = [def| + fun x = + l <- store x + k0 <- pure 3 + y <- tail k0 + pure x + |] + tailCalls fun `shouldBe` Nothing diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/InliningSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/InliningSpec.hs new file mode 100644 index 00000000..c611fe43 --- /dev/null +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/InliningSpec.hs @@ -0,0 +1,59 @@ +{-# LANGUAGE OverloadedStrings, QuasiQuotes, ViewPatterns #-} +module Transformations.ExtendedSyntax.Optimising.InliningSpec where + +import Transformations.ExtendedSyntax.Optimising.Inlining + +import qualified Data.Set as Set + +import Test.Hspec +import Test.ExtendedSyntax.Assertions +import Grin.ExtendedSyntax.TH +import Grin.ExtendedSyntax.TypeCheck +import Transformations.ExtendedSyntax.Names (ExpChanges(..)) + + +runTests :: IO () +runTests = hspec spec + +spec :: Spec +spec = do + it "base case" $ do + let before = [prog| + grinMain = + k <- pure 0 + x <- funA k + y <- funA k + pure x + + funA i = pure i + |] + let after = [prog| + grinMain = + k <- pure 0 + x <- do + i.0 <- pure k + pure i.0 + y <- do + i.1 <- pure k + pure i.1 + pure x + + funA i = pure i + |] + let inlineSet = Set.fromList ["funA"] + inlining inlineSet (inferTypeEnv before) before `sameAs` (after, NewNames) + + it "no-inline grinMain" $ do + let before = [prog| + grinMain = + k <- pure 0 + x <- pure k + pure x + |] + let after = [prog| + grinMain = + k <- pure 0 + x <- pure k + pure x + |] + lateInlining (inferTypeEnv before) before `sameAs` (after, NoChange) diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/NonSharedEliminationSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/NonSharedEliminationSpec.hs new file mode 100644 index 00000000..a845f847 --- /dev/null +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/NonSharedEliminationSpec.hs @@ -0,0 +1,86 @@ +{-# LANGUAGE OverloadedStrings, QuasiQuotes, ViewPatterns #-} +module Transformations.ExtendedSyntax.Optimising.NonSharedEliminationSpec where + +import Transformations.ExtendedSyntax.Optimising.NonSharedElimination + +import qualified Data.Set as Set +import qualified Data.Map as Map + +import Test.Hspec +import Test.ExtendedSyntax.Util (loc) +import Test.ExtendedSyntax.Assertions +import Test.ExtendedSyntax.New.Test (testExprContextE) + +import Grin.ExtendedSyntax.Syntax +import AbstractInterpretation.ExtendedSyntax.HeapPointsTo.Result as HPT +import AbstractInterpretation.ExtendedSyntax.Sharing.Result +import Grin.ExtendedSyntax.TH (expr) + + +runTests :: IO () +runTests = hspec spec + +nonSharedElimination' :: SharingResult -> Exp -> Exp +nonSharedElimination' shRes = fst . nonSharedElimination shRes + +-- NOTE: The type environments are partial, because we only need pointer information. + +spec :: Spec +spec = do + testExprContextE $ \ctx -> do + it "simple non-shared" $ do + let before = [expr| + n1 <- pure (COne) + p1 <- store n1 + v1 <- fetch p1 + n2 <- pure (CTwo) + _1 <- update p1 n2 + pure () + |] + let after = [expr| + n1 <- pure (COne) + p1 <- store n1 + v1 <- fetch p1 + n2 <- pure (CTwo) + pure () + |] + + let hptResult = HPTResult + { HPT._memory = mempty + , HPT._register = Map.fromList [ ("p1", loc 0)] + , HPT._function = mempty + } + sharedLocs = mempty + shResult = SharingResult hptResult sharedLocs + + nonSharedElimination' shResult (ctx before) `sameAs` (ctx after) + + it "simple shared" $ do + let before = [expr| + n1 <- pure (COne) + p1 <- store n1 + v1 <- fetch p1 + n2 <- pure (CTwo) + _1 <- update p1 n2 + v2 <- fetch p1 + pure () + |] + let after = [expr| + n1 <- pure (COne) + p1 <- store n1 + v1 <- fetch p1 + n2 <- pure (CTwo) + _1 <- update p1 n2 + v2 <- fetch p1 + pure () + |] + + let hptResult = HPTResult + { HPT._memory = mempty + , HPT._register = Map.fromList [ ("p1", loc 0)] + , HPT._function = mempty + } + sharedLocs = Set.fromList [0] + shResult = SharingResult hptResult sharedLocs + + nonSharedElimination' shResult (ctx before) `sameAs` (ctx after) diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/SimpleDeadFunctionEliminationSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/SimpleDeadFunctionEliminationSpec.hs new file mode 100644 index 00000000..a6b03d73 --- /dev/null +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/SimpleDeadFunctionEliminationSpec.hs @@ -0,0 +1,108 @@ +{-# LANGUAGE OverloadedStrings, QuasiQuotes, ViewPatterns #-} +module Transformations.ExtendedSyntax.Optimising.SimpleDeadFunctionEliminationSpec where + +import Transformations.Optimising.SimpleDeadFunctionElimination + +import Test.Hspec +import Grin.TH +import Test.Test hiding (newVar) +import Test.Assertions + + +runTests :: IO () +runTests = hspec spec + +spec :: Spec +spec = do + it "simple" $ do + let before = [prog| + grinMain = + x <- pure 1 + funA x + funB x + + funA a = pure () + funB b = funC b + funC c = pure () + + deadFunA d = pure d + deadFunB e = deadFunA e + |] + let after = [prog| + grinMain = + x <- pure 1 + funA x + funB x + + funA a = pure () + funB b = funC b + funC c = pure () + |] + simpleDeadFunctionElimination before `sameAs` after + + it "reference direction" $ do + let before = [prog| + grinMain = + x <- pure 1 + funA x + + funA b = funB b + funB c = pure () + + deadFunA d = funA d + deadFunB e = deadFunA e + |] + let after = [prog| + grinMain = + x <- pure 1 + funA x + + funA b = funB b + funB c = pure () + |] + simpleDeadFunctionElimination before `sameAs` after + + it "ignore unknown function" $ do + let before = [prog| + grinMain = + x <- pure 1 + funA x + funB x + + deadFunA d = pure d + deadFunB e = deadFunA e + |] + let after = [prog| + grinMain = + x <- pure 1 + funA x + funB x + |] + simpleDeadFunctionElimination before `sameAs` after + + it "dead clique" $ do + let before = [prog| + grinMain = + x <- pure 1 + funA x + + funA b = funB b + funB c = pure () + + deadFunA d = + v1 <- funA d + deadFunB d + + deadFunB e = + v2 <- funA d + deadFunA e + |] + let after = [prog| + grinMain = + x <- pure 1 + funA x + + funA b = funB b + funB c = pure () + |] + simpleDeadFunctionElimination before `sameAs` after diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/SimpleDeadParameterEliminationSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/SimpleDeadParameterEliminationSpec.hs new file mode 100644 index 00000000..a09b0435 --- /dev/null +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/SimpleDeadParameterEliminationSpec.hs @@ -0,0 +1,79 @@ +{-# LANGUAGE OverloadedStrings, QuasiQuotes, ViewPatterns #-} +module Transformations.ExtendedSyntax.Optimising.SimpleDeadParameterEliminationSpec where + +import Transformations.ExtendedSyntax.Optimising.SimpleDeadParameterElimination + +import Test.Hspec + +import Grin.ExtendedSyntax.TH +import Test.ExtendedSyntax.Assertions + + +runTests :: IO () +runTests = hspec spec + +spec :: Spec +spec = do + it "simple" $ do + let before = [prog| + funA a b = pure b + funB c = + k0 <- pure 1 + funA c k0 + |] + let after = [prog| + funA b = pure b + funB c = + k0 <- pure 1 + funA k0 + |] + simpleDeadParameterElimination before `sameAs` after + + it "Pnode + Fnode ; val - lpat - cpat" $ do + let before = [prog| + funA a b = pure b + funB c = + k0 <- pure 1 + funA c k0 + + eval p = + v <- fetch p + case v of + (FfunB c1) @ alt1 -> funB c1 + (FfunA a1 b1) @ alt2 -> + (FfunA a2 b2) @ _1 <- pure (FfunA a1 b1) + funA a2 b2 + (P2funA) @ alt3 -> + (P2funA) @ _2 <- pure (P2funA) + pure (P2funA) + (P1funA a3) @ alt4 -> + (P1funA a4) @ _3 <- pure (P1funA a3) + pure (P1funA a4) + (P0funA a5 b5) @ alt5 -> + (P0funA a6 b6) @ _4 <- pure (P0funA a5 b5) + pure (P0funA a6 b6) + |] + let after = [prog| + funA b = pure b + funB c = + k0 <- pure 1 + funA k0 + + eval p = + v <- fetch p + case v of + (FfunB c1) @ alt1 -> funB c1 + (FfunA b1) @ alt2 -> + (FfunA b2) @ _1 <- pure (FfunA b1) + funA b2 + (P2funA) @ alt3 -> + (P2funA) @ _2 <- pure (P2funA) + pure (P2funA) + (P1funA) @ alt4 -> + (P1funA) @ _3 <- pure (P1funA) + pure (P1funA) + (P0funA b5) @ alt5 -> + (P0funA b6) @ _4 <- pure (P0funA b5) + pure (P0funA b6) + |] + simpleDeadParameterElimination before `sameAs` after diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/SimpleDeadVariableEliminationSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/SimpleDeadVariableEliminationSpec.hs new file mode 100644 index 00000000..5327decf --- /dev/null +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/SimpleDeadVariableEliminationSpec.hs @@ -0,0 +1,328 @@ +{-# LANGUAGE OverloadedStrings, QuasiQuotes, ViewPatterns #-} +module Transformations.ExtendedSyntax.Optimising.SimpleDeadVariableEliminationSpec where + +import Transformations.ExtendedSyntax.Optimising.SimpleDeadVariableElimination +import Transformations.ExtendedSyntax.EffectMap + +import Test.Hspec + +import Test.ExtendedSyntax.Assertions +import Grin.ExtendedSyntax.TH +import Grin.ExtendedSyntax.PrimOpsPrelude +import Grin.ExtendedSyntax.TypeCheck + + +runTests :: IO () +runTests = hspec spec + + +spec :: Spec +spec = do + describe "Bugs" $ do + it "keep blocks" $ do + let before = withPrimPrelude [prog| + grinMain = + fun_main.0 <- pure (P1Main.main.closure.0) + p.1.0 <- pure fun_main.0 + "unboxed.C\"GHC.Prim.Unit#\".0" <- do + result_Main.main1.0.0.0 <- pure (P1Main.main1.closure.0) + apply.unboxed2 $ result_Main.main1.0.0.0 + k0 <- pure 0 + _prim_int_print $ k0 + + apply.unboxed2 p.1.X = + do + (P1Main.main1.closure.0) @ v0 <- pure p.1.X + k1 <- pure 12 + _1 <- _prim_int_print $ k1 + n0 <- pure (F"GHC.Tuple.()") + store n0 + |] + let after = withPrimPrelude [prog| + grinMain = + "unboxed.C\"GHC.Prim.Unit#\".0" <- do + result_Main.main1.0.0.0 <- pure (P1Main.main1.closure.0) + apply.unboxed2 $ result_Main.main1.0.0.0 + k0 <- pure 0 + _prim_int_print $ k0 + + apply.unboxed2 p.1.X = + do + k1 <- pure 12 + _1 <- _prim_int_print $ k1 + n0 <- pure (F"GHC.Tuple.()") + store n0 + |] + let tyEnv = inferTypeEnv before + effMap = effectMap (tyEnv, before) + dveExp = simpleDeadVariableElimination effMap before + dveExp `sameAs` after + + it "do not remove effectful case" $ do + let before = withPrimPrelude [prog| + sideeff s1 = + s2 <- _prim_int_add s1 s1 + _prim_int_print s2 + + grinMain = + k1 <- pure 0 + n0 <- pure (CInt k1) + x <- case n0 of + (CInt x1) @ alt1 -> + _1 <- sideeff x1 + pure 1 -- pure (CInt 1) + (CFloat y1) @ alt2 -> + y2 <- _prim_int_add k1 k1 + pure 2 -- pure (CInt y2) + pure () + |] + let after = withPrimPrelude [prog| + sideeff s1 = + s2 <- _prim_int_add s1 s1 + _prim_int_print s2 + + grinMain = + k1 <- pure 0 + n0 <- pure (CInt k1) + x <- case n0 of + (CInt x1) @ alt1 -> + _1 <- sideeff x1 + pure 1 + (CFloat y1) @ alt2 -> + pure 2 + pure () + |] + let tyEnv = inferTypeEnv before + effMap = effectMap (tyEnv, before) + dveExp = simpleDeadVariableElimination effMap before + dveExp `sameAs` after + + it "do not remove effectful case 2" $ do + let before = withPrimPrelude [prog| + grinMain = + k0 <- pure #"str" + y <- pure (CInt k0) + x <- case y of + (CInt x1) @ alt1 -> + _1 <- _prim_string_print x1 + pure 1 -- pure (CInt 1) + pure () + |] + let after = withPrimPrelude [prog| + grinMain = + k0 <- pure #"str" + y <- pure (CInt k0) + x <- case y of + (CInt x1) @ alt1 -> + _1 <- _prim_string_print x1 + pure 1 + pure () + |] + let tyEnv = inferTypeEnv before + effMap = effectMap (tyEnv, before) + dveExp = simpleDeadVariableElimination effMap before + dveExp `sameAs` after + + + describe "Simple dead variable elimination works for" $ do + it "simple" $ do + let before = [prog| + grinMain = + i1 <- pure 1 + n1 <- pure (CNode i1) + p1 <- store n1 + n2 <- pure (CNode p1) + p2 <- store n2 + pure 0 + |] + let after = [prog| + grinMain = + pure 0 + |] + let tyEnv = inferTypeEnv before + effMap = effectMap (tyEnv, before) + dveExp = simpleDeadVariableElimination effMap before + dveExp `sameAs` after + + it "pure case" $ do + let before = withPrimPrelude [prog| + grinMain = + i1 <- pure 1 + n1 <- pure (CNode i1) + p1 <- store n1 + n2 <- pure (CNode p1) + p2 <- store n2 + _1 <- _prim_int_print i1 + i2 <- case n1 of + 1 @ alt1 -> pure 2 + 2 @ alt2 -> pure 3 + #default @ alt3 -> pure 4 + pure 0 + |] + let after = withPrimPrelude [prog| + grinMain = + i1 <- pure 1 + _1 <- _prim_int_print i1 + pure 0 + |] + let tyEnv = inferTypeEnv before + effMap = effectMap (tyEnv, before) + dveExp = simpleDeadVariableElimination effMap before + dveExp `sameAs` after + + it "effectful case" $ do + let before = withPrimPrelude [prog| + grinMain = + i1 <- pure 1 + n1 <- pure (CNode i1) + p1 <- store n1 + n2 <- pure (CNode p1) + p2 <- store n2 + _1 <- _prim_int_print i1 + _2 <- case n1 of + 1 @ alt1 -> pure () + 2 @ alt2 -> _prim_int_print i1 + #default @ alt3 -> pure () + pure 0 + |] + let after = withPrimPrelude [prog| + grinMain = + i1 <- pure 1 + n1 <- pure (CNode i1) + _1 <- _prim_int_print i1 + _2 <- case n1 of + 1 @ alt1 -> pure () + 2 @ alt2 -> _prim_int_print i1 + #default @ alt3 -> pure () + pure 0 + |] + let tyEnv = inferTypeEnv before + effMap = effectMap (tyEnv, before) + dveExp = simpleDeadVariableElimination effMap before + dveExp `sameAs` after + + it "nested effectful case" $ do + let before = withPrimPrelude [prog| + grinMain = + i1 <- pure 1 + n1 <- pure (CNode i1) + p1 <- store n1 + n2 <- pure (CNode p1) + p2 <- store n2 + _1 <- _prim_int_print i1 + _2 <- case n1 of + 1 @ alt1 -> + i2 <- case n2 of + 0 @ alt11 -> pure 1 + #default @ alt12 -> pure 2 + pure () + 2 @ alt2 -> _prim_int_print i1 + #default @ alt3 -> pure () + pure 0 + |] + let after = withPrimPrelude [prog| + grinMain = + i1 <- pure 1 + n1 <- pure (CNode i1) + _1 <- _prim_int_print i1 + _2 <- case n1 of + 1 @ alt1 -> pure () + 2 @ alt2 -> _prim_int_print i1 + #default @ alt3 -> pure () + pure 0 + |] + let tyEnv = inferTypeEnv before + effMap = effectMap (tyEnv, before) + dveExp = simpleDeadVariableElimination effMap before + dveExp `sameAs` after + + it "node pattern" $ do + let before = [prog| + grinMain = + i1 <- pure 1 + (CNode i2) @ v1 <- pure (CNode i1) + (CNode i3) @ v2 <- pure (CNode i1) + n1 <- pure (CNode i2) + (CNode i4) @ v3 <- pure n1 + pure i1 + |] + let after = [prog| + grinMain = + i1 <- pure 1 + pure i1 + |] + let tyEnv = inferTypeEnv before + effMap = effectMap (tyEnv, before) + dveExp = simpleDeadVariableElimination effMap before + dveExp `sameAs` after + + it "pattern match" $ do + let before = [prog| + grinMain = + i1 <- pure 0 + n1 <- pure (CNode i1) + (CNode i3) @ v1 <- pure n1 + (CNil) @ v2 <- pure (CNil) + (CUnit) @ v3 <- pure (CUnit) + n2 <- pure (CNode i3) + (CNode i4) @ v4 <- pure (CNode i3) + (CNode i5) @ v5 <- pure n2 + (CNode i6) @ v6 <- pure n2 + pure 0 + |] + let after = [prog| + grinMain = + pure 0 + |] + let tyEnv = inferTypeEnv before + effMap = effectMap (tyEnv, before) + dveExp = simpleDeadVariableElimination effMap before + dveExp `sameAs` after + + -- QUESTION: Does this belong here, or to DeadVariableEliminationSpec? + describe "Interprocedural DVE regression tests" $ do + it "not explicitly covered alternatives trigger undefined replacements" $ do + let before = withPrimPrelude [prog| + grinMain = + one <- pure 1 + two <- pure 2 + v0 <- _prim_int_add one one + v1 <- case v0 of + 2 @ alt1 -> + v2 <- _prim_int_lt one two + v3 <- case v2 of + #False @ alt11 -> pure v0 + #True @ alt12 -> pure 1 + case v3 of + 0 @ alt13 -> pure (CGT) + 1 @ alt14 -> pure (CLT) + 1 @ alt2 -> pure (CEQ) + -- If #default is changed to explicit alternatives the undefineds are not introduced. + -- Undefineds are introduced for missing alternatives too. + case v1 of + (CEQ) @ alt3 -> _prim_int_print one + #default @ alt4 -> _prim_int_print two + |] + let after = withPrimPrelude [prog| + grinMain = + one <- pure 1 + two <- pure 2 + v0 <- _prim_int_add one one + v1 <- case v0 of + 2 @ alt1 -> + v2 <- _prim_int_lt one two + v3 <- case v2 of + #False @ alt11 -> pure v0 + #True @ alt12 -> pure 1 + case v3 of + 0 @ alt13 -> pure (CGT) + 1 @ alt14 -> pure (CLT) + 1 @ alt2 -> pure (CEQ) + case v1 of + (CEQ) @ alt3 -> _prim_int_print one + #default @ alt4 -> _prim_int_print two + |] + let tyEnv = inferTypeEnv before + effMap = effectMap (tyEnv, before) + dveExp = simpleDeadVariableElimination effMap before + dveExp `sameAs` after diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/SparseCaseOptimisationSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/SparseCaseOptimisationSpec.hs new file mode 100644 index 00000000..e2051b34 --- /dev/null +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/SparseCaseOptimisationSpec.hs @@ -0,0 +1,102 @@ +{-# LANGUAGE OverloadedStrings, QuasiQuotes, ViewPatterns #-} +module Transformations.ExtendedSyntax.Optimising.SparseCaseOptimisationSpec where + +import Transformations.ExtendedSyntax.Optimising.SparseCaseOptimisation + + +import qualified Data.Map as Map +import qualified Data.Vector as Vector + +import Test.Hspec +import Test.ExtendedSyntax.New.Test hiding (newVar) +import Test.ExtendedSyntax.Assertions + +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.TH +import Grin.ExtendedSyntax.TypeEnv + +-- TODO: Replace type env construction with new primitives from Test.ExtendedSyntax.Util + +runTests :: IO () +runTests = hspec spec + +spec :: Spec +spec = do + testExprContext $ \ctx -> do + it "Figure 4.25" $ do + let teBefore = create $ + (newVar "v" $ T_NodeSet (Map.fromList [(Tag C "Cons", Vector.fromList [T_Int64, T_Location [1]])])) + let before = [expr| + v <- eval l + case v of + (CNil) @ alt1 -> pure 1 + (CCons x xs) @ alt2 -> pure 2 + |] + let after = [expr| + v <- eval l + case v of + (CCons x xs) @ alt2 -> pure 2 + |] + let Right transformed = sparseCaseOptimisation teBefore before + ctx (teBefore, transformed) `sameAs` (ctx (teBefore, after)) + + it "Negative case, full context" $ do + let teBefore = create $ + (newVar "v" $ T_NodeSet (Map.fromList + [ (Tag C "Nil", Vector.fromList []) + , (Tag C "Cons", Vector.fromList [T_Int64, T_Location [1]]) + ])) + let before = [expr| + v <- eval l + case v of + (CNil) @ alt1 -> pure 1 + (CCons x xs) @ alt2 -> pure 2 + |] + let after = [expr| + v <- eval l + case v of + (CNil) @ alt1 -> pure 1 + (CCons x xs) @ alt2 -> pure 2 + |] + let Right transformed = sparseCaseOptimisation teBefore before + ctx (teBefore, transformed) `sameAs` (ctx (teBefore, after)) + + it "default" $ do + let teBefore = create $ + (newVar "v" $ T_NodeSet (Map.fromList [(Tag C "Cons", Vector.fromList [T_Int64, T_Location [1]])])) + let before = [expr| + v <- eval l + case v of + (CNil) @ alt1 -> pure 1 + (CCons x xs) @ alt2 -> pure 2 + #default @ alt3 -> pure 3 + |] + let after = [expr| + v <- eval l + case v of + (CCons x xs) @ alt2 -> pure 2 + |] + let Right transformed = sparseCaseOptimisation teBefore before + ctx (teBefore, transformed) `sameAs` (ctx (teBefore, after)) + + it "negative case with default" $ do + let teBefore = create $ + (newVar "v" $ T_NodeSet (Map.fromList + [ (Tag C "Nil2", Vector.fromList []) + , (Tag C "Cons", Vector.fromList [T_Int64, T_Location [1]]) + ])) + let before = [expr| + v <- eval l + case v of + (CNil) @ alt1 -> pure 1 + (CCons x xs) @ alt2 -> pure 2 + #default @ alt3 -> pure 3 + |] + let after = [expr| + v <- eval l + case v of + (CCons x xs) @ alt2 -> pure 2 + #default @ alt3 -> pure 3 + |] + let Right transformed = sparseCaseOptimisation teBefore before + ctx (teBefore, transformed) `sameAs` (ctx (teBefore, after))