1
1
mirror of https://github.com/anoma/juvix.git synced 2025-01-03 13:03:25 +03:00

Automatically detect and split mutually recursive blocks in let expressions (#1894)

- Closes #1677
This commit is contained in:
janmasrovira 2023-03-17 12:05:55 +01:00 committed by GitHub
parent da44ad6c6b
commit 934a273e2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 294 additions and 126 deletions

View File

@ -1,4 +1,4 @@
module Juvix.Compiler.Abstract.Extra.DependencyBuilder (buildDependencyInfo, ExportsTable) where
module Juvix.Compiler.Abstract.Extra.DependencyBuilder (buildDependencyInfo, buildDependencyInfoExpr, ExportsTable) where
import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
@ -18,7 +18,23 @@ type ExportsTable = HashSet NameId
buildDependencyInfo :: NonEmpty TopModule -> ExportsTable -> NameDependencyInfo
buildDependencyInfo ms tab =
createDependencyInfo graph startNodes
buildDependencyInfoHelper tab (mapM_ goModule ms)
buildDependencyInfoExpr :: Expression -> NameDependencyInfo
buildDependencyInfoExpr = buildDependencyInfoHelper mempty . goExpression Nothing
buildDependencyInfoHelper ::
ExportsTable ->
( Sem
'[ Reader ExportsTable,
State DependencyGraph,
State StartNodes,
State VisitedModules
]
()
) ->
NameDependencyInfo
buildDependencyInfoHelper tbl m = createDependencyInfo graph startNodes
where
startNodes :: StartNodes
graph :: DependencyGraph
@ -27,12 +43,14 @@ buildDependencyInfo ms tab =
evalState (HashSet.empty :: VisitedModules) $
runState HashSet.empty $
execState HashMap.empty $
runReader tab $
mapM_ goModule ms
runReader tbl m
addStartNode :: (Member (State StartNodes) r) => Name -> Sem r ()
addStartNode n = modify (HashSet.insert n)
addEdgeMay :: (Member (State DependencyGraph) r) => Maybe Name -> Name -> Sem r ()
addEdgeMay mn1 n2 = whenJust mn1 $ \n1 -> addEdge n1 n2
addEdge :: (Member (State DependencyGraph) r) => Name -> Name -> Sem r ()
addEdge n1 n2 =
modify
@ -87,7 +105,7 @@ goStatement modName = \case
StatementAxiom ax -> do
checkStartNode (ax ^. axiomName)
addEdge (ax ^. axiomName) modName
goExpression (ax ^. axiomName) (ax ^. axiomType)
goExpression (Just (ax ^. axiomName)) (ax ^. axiomType)
StatementFunction f -> goTopFunctionDef modName f
StatementImport m -> guardNotVisited (m ^. moduleName) (goModule m)
StatementLocalModule m -> goLocalModule modName m
@ -95,8 +113,8 @@ goStatement modName = \case
checkStartNode (i ^. inductiveName)
checkBuiltinInductiveStartNode i
addEdge (i ^. inductiveName) modName
mapM_ (goFunctionParameter (i ^. inductiveName)) (i ^. inductiveParameters)
goExpression (i ^. inductiveName) (i ^. inductiveType)
mapM_ (goFunctionParameter (Just (i ^. inductiveName))) (i ^. inductiveParameters)
goExpression (Just (i ^. inductiveName)) (i ^. inductiveType)
mapM_ (goConstructorDef (i ^. inductiveName)) (i ^. inductiveConstructors)
goTopFunctionDef :: (Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) => Name -> FunctionDef -> Sem r ()
@ -110,7 +128,7 @@ goFunctionDefHelper ::
Sem r ()
goFunctionDefHelper f = do
checkStartNode (f ^. funDefName)
goExpression (f ^. funDefName) (f ^. funDefTypeSig)
goExpression (Just (f ^. funDefName)) (f ^. funDefTypeSig)
mapM_ (goFunctionClause (f ^. funDefName)) (f ^. funDefClauses)
-- constructors of an inductive type depend on the inductive type, not the other
@ -118,14 +136,14 @@ goFunctionDefHelper f = do
goConstructorDef :: (Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) => Name -> InductiveConstructorDef -> Sem r ()
goConstructorDef indName c = do
addEdge (c ^. constructorName) indName
goExpression indName (c ^. constructorType)
goExpression (Just indName) (c ^. constructorType)
goFunctionClause :: (Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) => Name -> FunctionClause -> Sem r ()
goFunctionClause p c = do
mapM_ (goPattern p) (c ^. clausePatterns)
goExpression p (c ^. clauseBody)
mapM_ (goPattern (Just p)) (c ^. clausePatterns)
goExpression (Just p) (c ^. clauseBody)
goPattern :: forall r. (Member (State DependencyGraph) r) => Name -> PatternArg -> Sem r ()
goPattern :: forall r. (Member (State DependencyGraph) r) => Maybe Name -> PatternArg -> Sem r ()
goPattern n p = case p ^. patternArgPattern of
PatternVariable {} -> return ()
PatternWildcard {} -> return ()
@ -134,12 +152,17 @@ goPattern n p = case p ^. patternArgPattern of
where
goApp :: ConstructorApp -> Sem r ()
goApp (ConstructorApp ctr ps) = do
addEdge n (ctr ^. constructorRefName)
addEdgeMay n (ctr ^. constructorRefName)
mapM_ (goPattern n) ps
goExpression :: forall r. (Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) => Name -> Expression -> Sem r ()
goExpression ::
forall r.
(Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) =>
Maybe Name ->
Expression ->
Sem r ()
goExpression p e = case e of
ExpressionIden i -> addEdge p (idenName i)
ExpressionIden i -> addEdgeMay p (idenName i)
ExpressionUniverse {} -> return ()
ExpressionFunction f -> do
goFunctionParameter p (f ^. funParameter)
@ -177,8 +200,12 @@ goExpression p e = case e of
goLetClause :: LetClause -> Sem r ()
goLetClause = \case
LetFunDef f -> do
addEdge p (f ^. funDefName)
addEdgeMay p (f ^. funDefName)
goFunctionDefHelper f
goFunctionParameter :: (Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) => Name -> FunctionParameter -> Sem r ()
goFunctionParameter ::
(Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) =>
Maybe Name ->
FunctionParameter ->
Sem r ()
goFunctionParameter p param = goExpression p (param ^. paramType)

View File

@ -311,8 +311,9 @@ goFunctionDef ::
Sem r ()
goFunctionDef ((f, sym), ty) = do
mbody <- case f ^. Internal.funDefBuiltin of
Just b | isIgnoredBuiltin b -> return Nothing
Just _ -> Just <$> runReader initIndexTable (mkFunBody ty f)
Just b
| isIgnoredBuiltin b -> return Nothing
| otherwise -> Just <$> runReader initIndexTable (mkFunBody ty f)
Nothing -> Just <$> runReader initIndexTable (mkFunBody ty f)
forM_ mbody (registerIdentNode sym)
forM_ mbody setIdentArgsInfo'
@ -461,35 +462,33 @@ goLet ::
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable] r) =>
Internal.Let ->
Sem r Node
goLet l = do
vars <- asks (^. indexTableVars)
varsNum <- asks (^. indexTableVarsNum)
let bs :: [Name]
bs = map (\(Internal.LetFunDef Internal.FunctionDef {..}) -> _funDefName) (toList $ l ^. Internal.letClauses)
(vars', varsNum') =
foldl'
( \(vs, k) name ->
(HashMap.insert (name ^. nameId) k vs, k + 1)
)
(vars, varsNum)
bs
(defs, value) <- do
values <-
mapM
( \(Internal.LetFunDef f) -> do
funTy <- goType (f ^. Internal.funDefType)
funBody <- local (set indexTableVars vars' . set indexTableVarsNum varsNum') (mkFunBody funTy f)
return (funTy, funBody)
)
(l ^. Internal.letClauses)
lbody <-
local
(set indexTableVars vars' . set indexTableVarsNum varsNum')
(goExpression (l ^. Internal.letExpression))
return (values, lbody)
return $ mkLetRec' defs value
goLet l = goClauses (toList (l ^. Internal.letClauses))
where
goClauses :: [Internal.LetClause] -> Sem r Node
goClauses = \case
[] -> goExpression (l ^. Internal.letExpression)
c : cs -> case c of
Internal.LetFunDef f -> goNonRecFun f
Internal.LetMutualBlock m -> goMutual m
where
goNonRecFun :: Internal.FunctionDef -> Sem r Node
goNonRecFun f =
do
funTy <- goType (f ^. Internal.funDefType)
funBody <- mkFunBody funTy f
rest <- localAddName (f ^. Internal.funDefName) (goClauses cs)
return $ mkLet' funTy funBody rest
goMutual :: Internal.MutualBlock -> Sem r Node
goMutual (Internal.MutualBlock funs) = do
let lfuns = toList funs
names = map (^. Internal.funDefName) lfuns
tys = map (^. Internal.funDefType) lfuns
tys' <- mapM goType tys
localAddNames names $ do
vals' <- sequence [mkFunBody ty f | (ty, f) <- zipExact tys' lfuns]
let items = nonEmpty' (zip tys' vals')
rest <- goClauses cs
return (mkLetRec' items rest)
goAxiomInductive ::
forall r.

View File

@ -14,17 +14,24 @@ makeLenses ''IndexTable
initIndexTable :: IndexTable
initIndexTable = IndexTable 0 mempty
localAddName :: forall r a. (Member (Reader IndexTable) r) => Name -> Sem r a -> Sem r a
localAddName n s = do
localAddName :: Member (Reader IndexTable) r => Name -> Sem r a -> Sem r a
localAddName n = localAddNames [n]
localAddNames :: forall r a. (Member (Reader IndexTable) r) => [Name] -> Sem r a -> Sem r a
localAddNames names s = do
updateFn <- update
local updateFn s
where
len :: Int = length names
insertMany :: [(NameId, Index)] -> HashMap NameId Index -> HashMap NameId Index
insertMany l t = foldl' (\m (k, v) -> HashMap.insert k v m) t l
update :: Sem r (IndexTable -> IndexTable)
update = do
idx <- asks (^. indexTableVarsNum)
let newElems = zip (map (^. nameId) names) [idx ..]
return
( over indexTableVars (HashMap.insert (n ^. nameId) idx)
. over indexTableVarsNum (+ 1)
( over indexTableVars (insertMany newElems)
. over indexTableVarsNum (+ len)
)
underBinders :: Members '[Reader IndexTable] r => Int -> Sem r a -> Sem r a

View File

@ -73,9 +73,25 @@ extendWithReplExpression e =
over
infoFunctions
( HashMap.union
(HashMap.fromList [(f ^. funDefName, FunctionInfo f) | LetFunDef f <- universeBi e])
( HashMap.fromList
[ (f ^. funDefName, FunctionInfo f)
| f <- letFunctionDefs e
]
)
)
letFunctionDefs :: Data from => from -> [FunctionDef]
letFunctionDefs e =
concat
[ concatMap (toList . flattenClause) _letClauses
| Let {..} <- universeBi e
]
where
flattenClause :: LetClause -> NonEmpty FunctionDef
flattenClause = \case
LetFunDef f -> pure f
LetMutualBlock (MutualBlock fs) -> fs
-- | moduleName ↦ infoTable
type Cache = HashMap Name InfoTable
@ -117,7 +133,7 @@ buildTable1' m = do
]
<> [ (f ^. funDefName, FunctionInfo f)
| s <- filter (not . isInclude) ss,
LetFunDef f <- universeBi s
f <- letFunctionDefs s
]
where
isInclude :: Statement -> Bool

View File

@ -73,9 +73,14 @@ instance HasExpressions Case where
where
_caseParens = l ^. caseParens
instance HasExpressions MutualBlock where
leafExpressions f (MutualBlock defs) =
MutualBlock <$> traverse (leafExpressions f) defs
instance HasExpressions LetClause where
leafExpressions f = \case
LetFunDef d -> LetFunDef <$> leafExpressions f d
LetMutualBlock b -> LetMutualBlock <$> leafExpressions f b
instance HasExpressions Let where
leafExpressions f l = do

View File

@ -46,7 +46,9 @@ data Statement
newtype MutualBlock = MutualBlock
{ _mutualFunctions :: NonEmpty FunctionDef
}
deriving stock (Data)
deriving stock (Eq, Generic, Data)
instance Hashable MutualBlock
data AxiomDef = AxiomDef
{ _axiomName :: AxiomName,
@ -98,8 +100,10 @@ data TypedExpression = TypedExpression
_typedExpression :: Expression
}
newtype LetClause
= LetFunDef FunctionDef
data LetClause
= -- | Non-recursive let definition
LetFunDef FunctionDef
| LetMutualBlock MutualBlock
deriving stock (Eq, Generic, Data)
instance Hashable LetClause
@ -367,9 +371,13 @@ instance HasLoc FunctionClause where
instance HasLoc FunctionDef where
getLoc f = getLoc (f ^. funDefName) <> getLocSpan (f ^. funDefClauses)
instance HasLoc MutualBlock where
getLoc (MutualBlock defs) = getLocSpan defs
instance HasLoc LetClause where
getLoc = \case
LetFunDef f -> getLoc f
LetMutualBlock f -> getLoc f
instance HasLoc Let where
getLoc l = getLocSpan (l ^. letClauses) <> getLoc (l ^. letExpression)

View File

@ -91,8 +91,17 @@ instance PrettyCode Let where
return $ kwLet <+> letClauses' <+> kwIn <+> letExpression'
instance PrettyCode LetClause where
ppCode :: forall r. Member (Reader Options) r => LetClause -> Sem r (Doc Ann)
ppCode = \case
LetFunDef f -> ppCode f
LetMutualBlock b -> ppMutual b
where
ppMutual :: MutualBlock -> Sem r (Doc Ann)
ppMutual m@(MutualBlock b)
| [_] <- toList b = ppCode b
| otherwise = do
b' <- ppCode m
return (kwMutual <+> braces (line <> indent' b' <> line))
ppPipeBlock :: (PrettyCode a, Members '[Reader Options] r, Traversable t) => t a -> Sem r (Doc Ann)
ppPipeBlock items = vsep <$> mapM (fmap (kwPipe <+>) . ppCode) items

View File

@ -8,16 +8,15 @@ module Juvix.Compiler.Internal.Translation.FromAbstract
)
where
import Data.Graph
import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Abstract.Data.NameDependencyInfo qualified as Abstract
import Juvix.Compiler.Abstract.Data.NameDependencyInfo
import Juvix.Compiler.Abstract.Extra.DependencyBuilder
import Juvix.Compiler.Abstract.Extra.DependencyBuilder qualified as Abstract
import Juvix.Compiler.Abstract.Language qualified as Abstract
import Juvix.Compiler.Abstract.Translation.FromConcrete.Data.Context qualified as Abstract
import Juvix.Compiler.Internal.Extra
import Juvix.Compiler.Internal.Translation.FromAbstract.Analysis.Termination hiding (Graph)
import Juvix.Compiler.Internal.Translation.FromAbstract.Analysis.Termination
import Juvix.Compiler.Internal.Translation.FromAbstract.Data.Context
import Juvix.Compiler.Pipeline.EntryPoint qualified as E
import Juvix.Prelude
@ -69,10 +68,14 @@ fromAbstract abstractResults = do
abstractResults
^. Abstract.abstractResultEntryPoint
. E.entryPointNoTermination
depInfo :: NameDependencyInfo
depInfo = buildDependencyInfo (abstractResults ^. Abstract.resultModules) (abstractResults ^. Abstract.resultExports)
fromAbstractExpression :: (Members '[NameIdGen] r) => Abstract.Expression -> Sem r Expression
fromAbstractExpression = goExpression
fromAbstractExpression :: Members '[NameIdGen] r => Abstract.Expression -> Sem r Expression
fromAbstractExpression e = runReader depInfo (goExpression e)
where
depInfo :: NameDependencyInfo
depInfo = buildDependencyInfoExpr e
goModule ::
(Members '[Reader ExportsTable, State TranslationState, NameIdGen] r) =>
@ -80,42 +83,45 @@ goModule ::
Sem r Module
goModule m = do
expTbl <- ask
let mutualBlocks :: [NonEmpty Abstract.FunctionDef]
mutualBlocks = buildMutualBlocks expTbl
_moduleBody' <- goModuleBody mutualBlocks (m ^. Abstract.moduleBody)
examples' <- mapM goExample (m ^. Abstract.moduleExamples)
return
Module
{ _moduleName = m ^. Abstract.moduleName,
_moduleExamples = examples',
_moduleBody = _moduleBody'
}
let depInfo :: NameDependencyInfo
depInfo = Abstract.buildDependencyInfo (pure m) expTbl
runReader depInfo $ do
mutualBlocks :: [SCC Abstract.FunctionDef] <- buildMutualBlocks moduleFunctionDefs
_moduleBody' <- goModuleBody (map flattenSCC mutualBlocks) (m ^. Abstract.moduleBody)
examples' <- mapM goExample (m ^. Abstract.moduleExamples)
return
Module
{ _moduleName = m ^. Abstract.moduleName,
_moduleExamples = examples',
_moduleBody = _moduleBody'
}
where
moduleFunctionDefs :: [Abstract.FunctionDef]
moduleFunctionDefs = [d | Abstract.StatementFunction d <- m ^. Abstract.moduleBody . Abstract.moduleStatements]
buildMutualBlocks :: Members '[Reader NameDependencyInfo] r => [Abstract.FunctionDef] -> Sem r [SCC Abstract.FunctionDef]
buildMutualBlocks defs = do
depInfo <- ask
let scomponents :: [SCC Abstract.Name] = buildSCCs depInfo
return (mapMaybe helper scomponents)
where
funsByName :: HashMap Abstract.FunctionName Abstract.FunctionDef
funsByName =
HashMap.fromList
[ (d ^. Abstract.funDefName, d)
| Abstract.StatementFunction d <- m ^. Abstract.moduleBody . Abstract.moduleStatements
]
funsByName = HashMap.fromList [(d ^. Abstract.funDefName, d) | d <- defs]
getFun :: Abstract.FunctionName -> Maybe Abstract.FunctionDef
getFun n = funsByName ^. at n
buildMutualBlocks :: Abstract.ExportsTable -> [NonEmpty Abstract.FunctionDef]
buildMutualBlocks expTbl = mapMaybe (nonEmpty . mapMaybe getFun . toList . fromNonEmptyTree) scomponents
helper :: SCC Abstract.Name -> Maybe (SCC Abstract.FunctionDef)
helper = nonEmptySCC . fmap getFun
where
fromNonEmptyTree :: Tree a -> NonEmpty a
fromNonEmptyTree = fromJust . nonEmpty . toList
depInfo :: Abstract.NameDependencyInfo
depInfo = Abstract.buildDependencyInfo (pure m) expTbl
graph :: Graph
graph = Abstract.buildDependencyInfo (pure m) expTbl ^. Abstract.depInfoGraph
scomponents :: [Tree Abstract.Name]
scomponents = fmap (Abstract.nameFromVertex depInfo) <$> scc graph
nonEmptySCC :: SCC (Maybe a) -> Maybe (SCC a)
nonEmptySCC = \case
AcyclicSCC a -> AcyclicSCC <$> a
CyclicSCC p -> CyclicSCC . toList <$> nonEmpty (catMaybes p)
unsupported :: Text -> a
unsupported thing = error ("Abstract to Internal: Not yet supported: " <> thing)
goModuleBody ::
(Members '[Reader ExportsTable, State TranslationState, NameIdGen] r) =>
(Members '[Reader ExportsTable, Reader NameDependencyInfo, State TranslationState, NameIdGen] r) =>
[NonEmpty Abstract.FunctionDef] ->
Abstract.ModuleBody ->
Sem r ModuleBody
@ -143,7 +149,7 @@ goImport m = do
)
goStatement ::
(Members '[Reader ExportsTable, State TranslationState, NameIdGen] r) =>
(Members '[Reader ExportsTable, State TranslationState, NameIdGen, Reader NameDependencyInfo] r) =>
Abstract.Statement ->
Sem r (Maybe Statement)
goStatement = \case
@ -198,7 +204,7 @@ goFunction (Abstract.Function l r) = do
r' <- goType r
return (Function l' r')
goFunctionDef :: (Members '[NameIdGen] r) => Abstract.FunctionDef -> Sem r FunctionDef
goFunctionDef :: Members '[NameIdGen, Reader NameDependencyInfo] r => Abstract.FunctionDef -> Sem r FunctionDef
goFunctionDef f = do
_funDefClauses' <- mapM (goFunctionClause _funDefName') (f ^. Abstract.funDefClauses)
_funDefType' <- goType (f ^. Abstract.funDefTypeSig)
@ -215,7 +221,7 @@ goFunctionDef f = do
_funDefName' :: Name
_funDefName' = f ^. Abstract.funDefName
goExample :: (Members '[NameIdGen] r) => Abstract.Example -> Sem r Example
goExample :: Members '[NameIdGen, Reader NameDependencyInfo] r => Abstract.Example -> Sem r Example
goExample e = do
e' <- goExpression (e ^. Abstract.exampleExpression)
return
@ -224,7 +230,7 @@ goExample e = do
_exampleId = e ^. Abstract.exampleId
}
goFunctionClause :: (Members '[NameIdGen] r) => Name -> Abstract.FunctionClause -> Sem r FunctionClause
goFunctionClause :: Members '[NameIdGen, Reader NameDependencyInfo] r => Name -> Abstract.FunctionClause -> Sem r FunctionClause
goFunctionClause n c = do
_clauseBody' <- goExpression (c ^. Abstract.clauseBody)
_clausePatterns' <- mapM goPatternArg (c ^. Abstract.clausePatterns)
@ -287,7 +293,7 @@ goType e = case e of
Abstract.ExpressionLet {} -> unsupported "let in types"
Abstract.ExpressionCase {} -> unsupported "case in types"
goLambda :: forall r. (Members '[NameIdGen] r) => Abstract.Lambda -> Sem r Lambda
goLambda :: forall r. Members '[NameIdGen, Reader NameDependencyInfo] r => Abstract.Lambda -> Sem r Lambda
goLambda (Abstract.Lambda cl') = do
_lambdaClauses <- mapM goClause cl'
let _lambdaType :: Maybe Expression = Nothing
@ -304,7 +310,7 @@ goLambda (Abstract.Lambda cl') = do
Explicit -> p
Implicit -> unsupported "implicit patterns in lambda"
goApplication :: (Members '[NameIdGen] r) => Abstract.Application -> Sem r Application
goApplication :: Members '[NameIdGen, Reader NameDependencyInfo] r => Abstract.Application -> Sem r Application
goApplication (Abstract.Application f x i) = do
f' <- goExpression f
x' <- goExpression x
@ -318,7 +324,7 @@ goIden i = case i of
Abstract.IdenAxiom a -> IdenAxiom (a ^. Abstract.axiomRefName)
Abstract.IdenInductive a -> IdenInductive (a ^. Abstract.inductiveRefName)
goExpressionFunction :: forall r. (Members '[NameIdGen] r) => Abstract.Function -> Sem r Function
goExpressionFunction :: forall r. Members '[NameIdGen, Reader NameDependencyInfo] r => Abstract.Function -> Sem r Function
goExpressionFunction f = do
l' <- goParam (f ^. Abstract.funParameter)
r' <- goExpression (f ^. Abstract.funReturn)
@ -329,7 +335,7 @@ goExpressionFunction f = do
ty' <- goExpression (p ^. Abstract.paramType)
return (FunctionParameter (p ^. Abstract.paramName) (p ^. Abstract.paramImplicit) ty')
goExpression :: (Members '[NameIdGen] r) => Abstract.Expression -> Sem r Expression
goExpression :: Members '[NameIdGen, Reader NameDependencyInfo] r => Abstract.Expression -> Sem r Expression
goExpression e = case e of
Abstract.ExpressionIden i -> return (ExpressionIden (goIden i))
Abstract.ExpressionUniverse u -> return (ExpressionUniverse (goUniverse u))
@ -341,7 +347,7 @@ goExpression e = case e of
Abstract.ExpressionLet l -> ExpressionLet <$> goLet l
Abstract.ExpressionCase c -> ExpressionCase <$> goCase c
goCase :: Members '[NameIdGen] r => Abstract.Case -> Sem r Case
goCase :: Members '[NameIdGen, Reader NameDependencyInfo] r => Abstract.Case -> Sem r Case
goCase c = do
_caseExpression <- goExpression (c ^. Abstract.caseExpression)
_caseBranches <- mapM goCaseBranch (c ^. Abstract.caseBranches)
@ -350,21 +356,25 @@ goCase c = do
_caseExpressionWholeType :: Maybe Expression = Nothing
return Case {..}
goCaseBranch :: Members '[NameIdGen] r => Abstract.CaseBranch -> Sem r CaseBranch
goCaseBranch :: Members '[NameIdGen, Reader NameDependencyInfo] r => Abstract.CaseBranch -> Sem r CaseBranch
goCaseBranch b = do
_caseBranchPattern <- goPatternArg (b ^. Abstract.caseBranchPattern)
_caseBranchExpression <- goExpression (b ^. Abstract.caseBranchExpression)
return CaseBranch {..}
goLetClause :: (Members '[NameIdGen] r) => Abstract.LetClause -> Sem r LetClause
goLetClause = \case
Abstract.LetFunDef f -> LetFunDef <$> goFunctionDef f
goLet :: (Members '[NameIdGen] r) => Abstract.Let -> Sem r Let
goLet :: forall r. (Members '[NameIdGen, Reader NameDependencyInfo] r) => Abstract.Let -> Sem r Let
goLet l = do
_letExpression <- goExpression (l ^. Abstract.letExpression)
_letClauses <- mapM goLetClause (l ^. Abstract.letClauses)
mutualBlocks <- buildMutualBlocks funDefs
_letClauses <- nonEmpty' <$> mapM goLetBlock mutualBlocks
return Let {..}
where
funDefs :: [Abstract.FunctionDef]
funDefs = [f | Abstract.LetFunDef f <- toList (l ^. Abstract.letClauses)]
goLetBlock :: SCC Abstract.FunctionDef -> Sem r LetClause
goLetBlock = \case
AcyclicSCC f -> LetFunDef <$> goFunctionDef f
CyclicSCC m -> LetMutualBlock . MutualBlock <$> mapM goFunctionDef (nonEmpty' m)
goInductiveParameter :: Abstract.FunctionParameter -> Sem r InductiveParameter
goInductiveParameter f =
@ -378,7 +388,7 @@ goInductiveParameter f =
(Just {}, _) -> unsupported "only type variables of small types are allowed"
(Nothing, _) -> unsupported "unnamed inductive parameters"
goInductiveDef :: forall r. (Members '[NameIdGen] r) => Abstract.InductiveDef -> Sem r InductiveDef
goInductiveDef :: forall r. Members '[NameIdGen, Reader NameDependencyInfo] r => Abstract.InductiveDef -> Sem r InductiveDef
goInductiveDef i
| not (isSmallType (i ^. Abstract.inductiveType)) = unsupported "inductive indices"
| otherwise = do

View File

@ -11,7 +11,7 @@ import Juvix.Compiler.Internal.Translation.FromAbstract.Analysis.Termination.Dat
import Juvix.Prelude
import Prettyprinter qualified as PP
type Graph = HashMap (FunctionName, FunctionName) Edge
type EdgeMap = HashMap (FunctionName, FunctionName) Edge
data Edge = Edge
{ _edgeFrom :: FunctionName,
@ -19,7 +19,7 @@ data Edge = Edge
_edgeMatrices :: HashSet CallMatrix
}
newtype CompleteCallGraph = CompleteCallGraph Graph
newtype CompleteCallGraph = CompleteCallGraph EdgeMap
data ReflexiveEdge = ReflexiveEdge
{ _reflexiveEdgeFun :: FunctionName,

View File

@ -9,7 +9,7 @@ import Juvix.Compiler.Abstract.Extra
import Juvix.Compiler.Internal.Translation.FromAbstract.Analysis.Termination.Data
import Juvix.Prelude
fromEdgeList :: [Edge] -> Graph
fromEdgeList :: [Edge] -> EdgeMap
fromEdgeList l = HashMap.fromList [((e ^. edgeFrom, e ^. edgeTo), e) | e <- l]
composeEdge :: Edge -> Edge -> Maybe Edge
@ -22,7 +22,7 @@ composeEdge a b = do
_edgeMatrices = multiplyMany (a ^. edgeMatrices) (b ^. edgeMatrices)
}
edgesCompose :: Graph -> Graph -> Graph
edgesCompose :: EdgeMap -> EdgeMap -> EdgeMap
edgesCompose g h =
fromEdgeList
(catMaybes [composeEdge u v | u <- toList g, v <- toList h])
@ -37,10 +37,10 @@ edgeUnion a b
(HashSet.union (a ^. edgeMatrices) (b ^. edgeMatrices))
| otherwise = impossible
edgesUnion :: Graph -> Graph -> Graph
edgesUnion :: EdgeMap -> EdgeMap -> EdgeMap
edgesUnion = HashMap.unionWith edgeUnion
edgesCount :: Graph -> Int
edgesCount :: EdgeMap -> Int
edgesCount es = sum [HashSet.size (e ^. edgeMatrices) | e <- toList es]
multiply :: CallMatrix -> CallMatrix -> CallMatrix
@ -77,10 +77,10 @@ unsafeFilterGraph funNames (CompleteCallGraph g) =
completeCallGraph :: CallMap -> CompleteCallGraph
completeCallGraph CallMap {..} = CompleteCallGraph (go startingEdges)
where
startingEdges :: Graph
startingEdges :: EdgeMap
startingEdges = foldr insertCall mempty allCalls
where
insertCall :: Call -> Graph -> Graph
insertCall :: Call -> EdgeMap -> EdgeMap
insertCall Call {..} = HashMap.alter (Just . aux) (_callFrom, _callTo)
where
aux :: Maybe Edge -> Edge
@ -96,14 +96,14 @@ completeCallGraph CallMap {..} = CompleteCallGraph (go startingEdges)
funCall <- funCalls
]
go :: Graph -> Graph
go :: EdgeMap -> EdgeMap
go g
| edgesCount g == edgesCount g' = g
| otherwise = go g'
where
g' = step g
step :: Graph -> Graph
step :: EdgeMap -> EdgeMap
step s = edgesUnion (edgesCompose s startingEdges) s
reflexiveEdges :: CompleteCallGraph -> [ReflexiveEdge]

View File

@ -377,6 +377,7 @@ checkLet ari l = do
checkLetClause :: LetClause -> Sem r LetClause
checkLetClause = \case
LetFunDef f -> LetFunDef <$> checkFunctionDef f
LetMutualBlock f -> LetMutualBlock <$> checkMutualBlock f
checkLambda ::
forall r.

View File

@ -89,6 +89,10 @@ checkStrictlyPositiveOccurrences ty ctorName name recLimit ref =
helperLetClause :: LetClause -> Sem r ()
helperLetClause = \case
LetFunDef f -> helperFunctionDef f
LetMutualBlock b -> helperMutualBlock b
helperMutualBlock :: MutualBlock -> Sem r ()
helperMutualBlock b = mapM_ helperFunctionDef (b ^. mutualFunctions)
helperFunctionDef :: FunctionDef -> Sem r ()
helperFunctionDef d = do

View File

@ -57,7 +57,7 @@ checkStatement ::
Statement ->
Sem r Statement
checkStatement s = case s of
StatementFunction funs -> StatementFunction <$> runReader emptyLocalVars (checkMutualBlock funs)
StatementFunction funs -> StatementFunction <$> runReader emptyLocalVars (checkTopMutualBlock funs)
StatementInductive ind -> StatementInductive <$> checkInductiveDef ind
StatementInclude i -> StatementInclude <$> checkInclude i
StatementAxiom ax -> do
@ -125,11 +125,11 @@ checkInductiveDef InductiveDef {..} = runInferenceDef $ do
withEmptyVars :: Sem (Reader LocalVars : r) a -> Sem r a
withEmptyVars = runReader emptyLocalVars
checkMutualBlock ::
checkTopMutualBlock ::
(Members '[Reader LocalVars, Reader InfoTable, Error TypeCheckerError, NameIdGen, State TypesTable, State FunctionsTable, Output Example, Builtins] r) =>
MutualBlock ->
Sem r MutualBlock
checkMutualBlock (MutualBlock ds) = MutualBlock <$> runInferenceDefs (mapM checkFunctionDef ds)
checkTopMutualBlock (MutualBlock ds) = MutualBlock <$> runInferenceDefs (mapM checkFunctionDef ds)
checkFunctionDef ::
(Members '[Reader LocalVars, Reader InfoTable, Error TypeCheckerError, NameIdGen, State TypesTable, State FunctionsTable, Output Example, Builtins, Inference] r) =>
@ -385,7 +385,7 @@ checkPattern = go
indName = IdenInductive (info ^. constructorInfoInductive)
loc = getLoc a
paramHoles <- map ExpressionHole <$> replicateM numIndParams (freshHole loc)
let patternTy = foldApplication (ExpressionIden indName) (zip (repeat Explicit) paramHoles)
let patternTy = foldApplication (ExpressionIden indName) (map (Explicit,) paramHoles)
whenJustM
(matchTypes patternTy (ExpressionHole hole))
err
@ -524,10 +524,13 @@ inferExpression' hint e = case e of
}
}
-- what about mutually recursive lets?
goLetClause :: LetClause -> Sem r LetClause
goLetClause = \case
LetFunDef f -> LetFunDef <$> checkFunctionDef f
LetMutualBlock b -> LetMutualBlock <$> goMutualLet b
where
goMutualLet :: MutualBlock -> Sem r MutualBlock
goMutualLet (MutualBlock fs) = MutualBlock <$> mapM checkFunctionDef fs
goHole :: Hole -> Sem r TypedExpression
goHole h = do

View File

@ -56,6 +56,9 @@ primitive = annotate (AnnKind KNameAxiom) . pretty
keyword :: Text -> Doc Ann
keyword = annotate AnnKeyword . pretty
kwMutual :: Doc Ann
kwMutual = keyword Str.mutual
kwLambda :: Doc Ann
kwLambda = keyword Str.lambdaUnicode

View File

@ -1,6 +1,5 @@
module Juvix.Data.DependencyInfo where
import Data.Graph (Graph, Vertex)
import Data.Graph qualified as Graph
import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
@ -13,6 +12,7 @@ import Juvix.Prelude.Base
data DependencyInfo n = DependencyInfo
{ _depInfoGraph :: Graph,
_depInfoNodeFromVertex :: Vertex -> (n, HashSet n),
_depInfoEdgeList :: [(n, n, [n])],
_depInfoVertexFromName :: n -> Maybe Vertex,
_depInfoReachable :: HashSet n,
_depInfoTopSort :: [n]
@ -25,6 +25,7 @@ createDependencyInfo edges startNames =
DependencyInfo
{ _depInfoGraph = graph,
_depInfoNodeFromVertex = \v -> let (_, x, y) = nodeFromVertex v in (x, HashSet.fromList y),
_depInfoEdgeList = edgeList,
_depInfoVertexFromName = vertexFromName,
_depInfoReachable = reachableNames,
_depInfoTopSort = topSortedNames
@ -33,9 +34,9 @@ createDependencyInfo edges startNames =
graph :: Graph
nodeFromVertex :: Vertex -> (n, n, [n])
vertexFromName :: n -> Maybe Vertex
(graph, nodeFromVertex, vertexFromName) =
Graph.graphFromEdges $
map (\(x, y) -> (x, x, HashSet.toList y)) (HashMap.toList edges)
(graph, nodeFromVertex, vertexFromName) = Graph.graphFromEdges edgeList
edgeList :: [(n, n, [n])]
edgeList = map (\(x, y) -> (x, x, HashSet.toList y)) (HashMap.toList edges)
reachableNames :: HashSet n
reachableNames =
HashSet.fromList $
@ -51,3 +52,6 @@ nameFromVertex depInfo v = fst $ (depInfo ^. depInfoNodeFromVertex) v
isReachable :: (Hashable n) => DependencyInfo n -> n -> Bool
isReachable depInfo n = HashSet.member n (depInfo ^. depInfoReachable)
buildSCCs :: Ord n => DependencyInfo n -> [SCC n]
buildSCCs = Graph.stronglyConnComp . (^. depInfoEdgeList)

View File

@ -210,3 +210,6 @@ kwVoid = asciiKw Str.void
kwDollar :: Keyword
kwDollar = asciiKw Str.dollar
kwMutual :: Keyword
kwMutual = asciiKw Str.mutual

View File

@ -311,6 +311,9 @@ mod = "%"
dollar :: (IsString s) => s
dollar = "$"
mutual :: (IsString s) => s
mutual = "mutual"
if_ :: (IsString s) => s
if_ = "if"

View File

@ -7,6 +7,7 @@
module Juvix.Prelude.Base
( module Juvix.Prelude.Base,
module Control.Applicative,
module Data.Graph,
module Data.Map.Strict,
module Data.Set,
module Data.IntMap.Strict,
@ -95,6 +96,7 @@ import Data.Eq
import Data.Foldable hiding (minimum, minimumBy)
import Data.Function
import Data.Functor
import Data.Graph (Graph, SCC (..), Vertex, stronglyConnComp)
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as HashMap
import Data.HashSet (HashSet)
@ -441,3 +443,9 @@ ensureFile f =
unlessM
(Path.doesFileExist f)
(throwM (mkIOError doesNotExistErrorType "" Nothing (Just (toFilePath f))))
-- Ideally `CyclicSCC`'s argument' would have type `NonEmpty a` instead of `[a]`
flattenSCC :: SCC a -> NonEmpty a
flattenSCC = \case
AcyclicSCC a -> pure a
CyclicSCC as -> nonEmpty' as

View File

@ -241,5 +241,10 @@ tests =
"Simple case expression"
$(mkRelDir ".")
$(mkRelFile "test038.juvix")
$(mkRelFile "out/test038.out")
$(mkRelFile "out/test038.out"),
posTest
"Mutually recursive let expression"
$(mkRelDir ".")
$(mkRelFile "test039.juvix")
$(mkRelFile "out/test039.out")
]

View File

@ -201,7 +201,11 @@ tests =
posTest
"Type synonym inside let"
$(mkRelDir "issue1879")
$(mkRelFile "LetSynonym.juvix")
$(mkRelFile "LetSynonym.juvix"),
posTest
"Mutual inference inside let"
$(mkRelDir ".")
$(mkRelFile "MutualLet.juvix")
]
<> [ compilationTest t | t <- Compilation.tests, t ^. Compilation.name /= "Self-application"
]

View File

@ -0,0 +1,2 @@
false
true

View File

@ -0,0 +1,22 @@
-- Mutually recursive let expressions
module test039;
open import Stdlib.Prelude;
main : IO;
main :=
let
Ty : Type;
Ty := Nat;
odd : _;
even : _;
unused : _;
odd zero := false;
odd (suc n) := not (even n);
unused := 123;
even zero := true;
even (suc n) := not (odd n);
plusOne : Ty → Ty;
plusOne n := n + 1;
in printBoolLn (odd (plusOne 13))
>> printBoolLn (even (plusOne 12));

View File

@ -0,0 +1,15 @@
module MutualLet;
open import Stdlib.Data.Nat;
open import Stdlib.Data.Bool;
main : _;
main :=
let
odd : _;
even : _;
odd zero := false;
odd (suc n) := not (even n);
even zero := true;
even (suc n) := not (odd n);
in even 5;

View File

@ -118,6 +118,16 @@ tests:
Nat
exit-status: 0
- name: eval-mutually-recursive-let-expression
command:
- juvix
- repl
stdin: "let even : Nat → Bool; odd : _; odd zero := false; odd (suc n) := not (even n); even zero := true; even (suc n) := not (odd n) in even 10"
stdout:
contains:
"true"
exit-status: 0
- name: eval-let-expression
command:
- juvix