mirror of
https://github.com/anoma/juvix.git
synced 2025-01-08 16:51:53 +03:00
Detect termination for nested local definitions (#3169)
* Closes #3147 When we call a function that is currently being defined (there may be several such due to nested local definitions), we add a reflexive edge in the call map instead of adding an edge from the most nested definition. For example, for ```juvix go {A B} (f : A -> B) : List A -> List B | nil := nil | (elem :: next) := let var1 := f elem; var2 := go f next; in var1 :: var2; ``` we add an edge from `go` to the recursive call `go f next`, instead of adding an edge from `var2` to `go f next` as before. This makes the above type-check. The following still doesn't type-check, because `next'` is not a subpattern of the clause pattern of `go`. But this is a less pressing problem. ```juvix go {A B} (f : A -> B) : List A -> List B | nil := nil | (elem :: next) := let var1 := f elem; var2 (next' : List A) : List B := go f next'; in myCons var1 (var2 next); ```
This commit is contained in:
parent
1d7bf1f25b
commit
9f25ffde16
@ -7,10 +7,11 @@ import Data.HashMap.Strict qualified as HashMap
|
||||
import Juvix.Compiler.Internal.Extra
|
||||
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data
|
||||
import Juvix.Prelude
|
||||
import Safe (headMay)
|
||||
|
||||
viewCall ::
|
||||
forall r.
|
||||
(Members '[Reader SizeInfo] r) =>
|
||||
(Members '[Reader SizeInfoMap] r) =>
|
||||
Expression ->
|
||||
Sem r (Maybe FunCall)
|
||||
viewCall = \case
|
||||
@ -19,12 +20,15 @@ viewCall = \case
|
||||
ExpressionApplication (Application f x impl)
|
||||
| isImplicitOrInstance impl -> viewCall f -- implicit arguments are ignored
|
||||
| otherwise -> do
|
||||
c <- viewCall f
|
||||
x' <- callArg
|
||||
return $ over callArgs (`snoc` x') <$> c
|
||||
mc <- viewCall f
|
||||
case mc of
|
||||
Just c -> do
|
||||
x' <- callArg (c ^. callRef)
|
||||
return $ Just $ over callArgs (`snoc` x') c
|
||||
Nothing -> return Nothing
|
||||
where
|
||||
callArg :: Sem r (CallRow, Expression)
|
||||
callArg = do
|
||||
callArg :: FunctionRef -> Sem r (CallRow, Expression)
|
||||
callArg fref = do
|
||||
lt <- (^. callRow) <$> lessThan
|
||||
eq <- (^. callRow) <$> equalTo
|
||||
return (CallRow (lt `mplus` eq), x)
|
||||
@ -33,7 +37,7 @@ viewCall = \case
|
||||
lessThan = case viewExpressionAsPattern x of
|
||||
Nothing -> return (CallRow Nothing)
|
||||
Just x' -> do
|
||||
s <- asks (findIndex (elem x') . (^. sizeSmaller))
|
||||
s <- asks (findIndex (elem x') . (^. sizeSmaller) . findSizeInfo)
|
||||
return $ case s of
|
||||
Nothing -> CallRow Nothing
|
||||
Just s' -> CallRow (Just (s', RLe))
|
||||
@ -41,11 +45,37 @@ viewCall = \case
|
||||
equalTo =
|
||||
case viewExpressionAsPattern x of
|
||||
Just x' -> do
|
||||
s <- asks (elemIndex x' . (^. sizeEqual))
|
||||
s <- asks (elemIndex x' . (^. sizeEqual) . findSizeInfo)
|
||||
return $ case s of
|
||||
Nothing -> CallRow Nothing
|
||||
Just s' -> CallRow (Just (s', REq))
|
||||
Nothing -> return (CallRow Nothing)
|
||||
findSizeInfo :: SizeInfoMap -> SizeInfo
|
||||
findSizeInfo infos =
|
||||
{-
|
||||
If the call is not to any nested function being defined, then we
|
||||
associate it with the most nested function. Without this,
|
||||
termination for mutually recursive functions doesn't work.
|
||||
|
||||
Consider:
|
||||
```
|
||||
isEven (x : Nat) : Bool :=
|
||||
let
|
||||
isEven' : Nat -> Bool
|
||||
| zero := true
|
||||
| (suc n) := isOdd' n;
|
||||
isOdd' : Nat -> Bool
|
||||
| zero := false
|
||||
| (suc n) := isEven' n;
|
||||
in isEven' x;
|
||||
```
|
||||
The call `isEven' n` inside `isOdd'` needs to be associated with
|
||||
`isOdd'`, not with `isEven`, and not just forgotten.
|
||||
-}
|
||||
fromMaybe (maybe emptySizeInfo snd . headMay $ infos ^. sizeInfoMap)
|
||||
. (lookup fref)
|
||||
. (^. sizeInfoMap)
|
||||
$ infos
|
||||
_ -> return Nothing
|
||||
where
|
||||
singletonCall :: FunctionRef -> FunCall
|
||||
|
@ -59,6 +59,7 @@ instance Scannable Expression where
|
||||
buildCallMap =
|
||||
run
|
||||
. execState emptyCallMap
|
||||
. runReader emptySizeInfoMap
|
||||
. scanTopExpression
|
||||
|
||||
runTerminationState :: TerminationState -> Sem (Termination ': r) a -> Sem r (TerminationState, a)
|
||||
@ -122,21 +123,21 @@ scanInductive i = do
|
||||
scanMutualStatement :: (Members '[State CallMap] r) => MutualStatement -> Sem r ()
|
||||
scanMutualStatement = \case
|
||||
StatementInductive i -> scanInductive i
|
||||
StatementFunction i -> scanFunctionDef i
|
||||
StatementFunction i -> runReader emptySizeInfoMap $ scanFunctionDef i
|
||||
StatementAxiom a -> scanAxiom a
|
||||
|
||||
scanAxiom :: (Members '[State CallMap] r) => AxiomDef -> Sem r ()
|
||||
scanAxiom = scanTopExpression . (^. axiomType)
|
||||
|
||||
scanFunctionDef ::
|
||||
(Members '[State CallMap] r) =>
|
||||
(Members '[State CallMap, Reader SizeInfoMap] r) =>
|
||||
FunctionDef ->
|
||||
Sem r ()
|
||||
scanFunctionDef f@FunctionDef {..} = do
|
||||
registerFunctionDef f
|
||||
runReader (Just _funDefName) $ do
|
||||
scanTypeSignature _funDefType
|
||||
scanFunctionBody _funDefBody
|
||||
scanFunctionBody _funDefName _funDefBody
|
||||
scanDefaultArgs _funDefArgsInfo
|
||||
|
||||
scanDefaultArgs ::
|
||||
@ -153,38 +154,41 @@ scanTypeSignature ::
|
||||
(Members '[State CallMap, Reader (Maybe FunctionRef)] r) =>
|
||||
Expression ->
|
||||
Sem r ()
|
||||
scanTypeSignature = runReader emptySizeInfo . scanExpression
|
||||
scanTypeSignature = runReader emptySizeInfoMap . scanExpression
|
||||
|
||||
scanFunctionBody ::
|
||||
forall r.
|
||||
(Members '[State CallMap, Reader (Maybe FunctionRef)] r) =>
|
||||
(Members '[State CallMap, Reader SizeInfoMap, Reader (Maybe FunctionRef)] r) =>
|
||||
FunctionName ->
|
||||
Expression ->
|
||||
Sem r ()
|
||||
scanFunctionBody topbody = go [] topbody
|
||||
scanFunctionBody funName topbody = go [] topbody
|
||||
where
|
||||
go :: [PatternArg] -> Expression -> Sem r ()
|
||||
go revArgs body = case body of
|
||||
ExpressionLambda Lambda {..} -> mapM_ goClause _lambdaClauses
|
||||
_ -> runReader (mkSizeInfo (reverse revArgs)) (scanExpression body)
|
||||
_ ->
|
||||
local
|
||||
(over sizeInfoMap ((funName, mkSizeInfo (reverse revArgs)) :))
|
||||
(scanExpression body)
|
||||
where
|
||||
goClause :: LambdaClause -> Sem r ()
|
||||
goClause (LambdaClause pats clBody) = go (reverse (toList pats) ++ revArgs) clBody
|
||||
|
||||
scanLet ::
|
||||
(Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfo] r) =>
|
||||
(Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfoMap] r) =>
|
||||
Let ->
|
||||
Sem r ()
|
||||
scanLet l = do
|
||||
mapM_ scanLetClause (l ^. letClauses)
|
||||
scanExpression (l ^. letExpression)
|
||||
|
||||
-- NOTE that we forget about the arguments of the hosting function
|
||||
scanLetClause :: (Members '[State CallMap] r) => LetClause -> Sem r ()
|
||||
scanLetClause :: (Members '[State CallMap, Reader SizeInfoMap] r) => LetClause -> Sem r ()
|
||||
scanLetClause = \case
|
||||
LetFunDef d -> scanFunctionDef d
|
||||
LetMutualBlock m -> scanMutualBlockLet m
|
||||
|
||||
scanMutualBlockLet :: (Members '[State CallMap] r) => MutualBlockLet -> Sem r ()
|
||||
scanMutualBlockLet :: (Members '[State CallMap, Reader SizeInfoMap] r) => MutualBlockLet -> Sem r ()
|
||||
scanMutualBlockLet MutualBlockLet {..} = mapM_ scanFunctionDef _mutualLet
|
||||
|
||||
scanTopExpression ::
|
||||
@ -192,18 +196,26 @@ scanTopExpression ::
|
||||
Expression ->
|
||||
Sem r ()
|
||||
scanTopExpression =
|
||||
runReader (Nothing @FunctionRef)
|
||||
. runReader emptySizeInfo
|
||||
runReader emptySizeInfoMap
|
||||
. runReader (Nothing @FunctionRef)
|
||||
. scanExpression
|
||||
|
||||
scanExpression ::
|
||||
(Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfo] r) =>
|
||||
(Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfoMap] r) =>
|
||||
Expression ->
|
||||
Sem r ()
|
||||
scanExpression e =
|
||||
viewCall e >>= \case
|
||||
Just c -> do
|
||||
whenJustM (ask @(Maybe FunctionRef)) (\caller -> runReader caller (registerCall c))
|
||||
-- Are we recursively calling a function being defined?
|
||||
recCall <- asks (elem (c ^. callRef) . map fst . (^. sizeInfoMap))
|
||||
if
|
||||
| recCall ->
|
||||
runReader (c ^. callRef) (registerCall c)
|
||||
| otherwise ->
|
||||
whenJustM
|
||||
(ask @(Maybe FunctionRef))
|
||||
(\caller -> runReader caller (registerCall c))
|
||||
mapM_ (scanExpression . snd) (c ^. callArgs)
|
||||
Nothing -> case e of
|
||||
ExpressionApplication a -> directExpressions_ scanExpression a
|
||||
|
@ -12,6 +12,12 @@ data SizeInfo = SizeInfo
|
||||
|
||||
makeLenses ''SizeInfo
|
||||
|
||||
newtype SizeInfoMap = SizeInfoMap
|
||||
{ _sizeInfoMap :: [(FunctionName, SizeInfo)]
|
||||
}
|
||||
|
||||
makeLenses ''SizeInfoMap
|
||||
|
||||
emptySizeInfo :: SizeInfo
|
||||
emptySizeInfo =
|
||||
SizeInfo
|
||||
@ -19,6 +25,9 @@ emptySizeInfo =
|
||||
_sizeSmaller = mempty
|
||||
}
|
||||
|
||||
emptySizeInfoMap :: SizeInfoMap
|
||||
emptySizeInfoMap = SizeInfoMap []
|
||||
|
||||
mkSizeInfo :: [PatternArg] -> SizeInfo
|
||||
mkSizeInfo args = SizeInfo {..}
|
||||
where
|
||||
|
@ -67,7 +67,15 @@ tests =
|
||||
PosTest
|
||||
"Ignore instance arguments"
|
||||
$(mkRelDir ".")
|
||||
$(mkRelFile "issue2414.juvix")
|
||||
$(mkRelFile "issue2414.juvix"),
|
||||
PosTest
|
||||
"Nested local definitions"
|
||||
$(mkRelDir ".")
|
||||
$(mkRelFile "Nested1.juvix"),
|
||||
PosTest
|
||||
"Named arguments"
|
||||
$(mkRelDir ".")
|
||||
$(mkRelFile "Nested2.juvix")
|
||||
]
|
||||
|
||||
testsWithKeyword :: [PosTest]
|
||||
|
11
tests/positive/Termination/Nested1.juvix
Normal file
11
tests/positive/Termination/Nested1.juvix
Normal file
@ -0,0 +1,11 @@
|
||||
module Nested1;
|
||||
|
||||
import Stdlib.Data.List open;
|
||||
|
||||
go {A B} (f : A -> B) : List A -> List B
|
||||
| nil := nil
|
||||
| (elem :: next) :=
|
||||
let
|
||||
var1 := f elem;
|
||||
var2 := go f next;
|
||||
in var1 :: var2;
|
16
tests/positive/Termination/Nested2.juvix
Normal file
16
tests/positive/Termination/Nested2.juvix
Normal file
@ -0,0 +1,16 @@
|
||||
module Nested2;
|
||||
|
||||
type MyList A :=
|
||||
| myNil
|
||||
| myCons@{
|
||||
elem : A;
|
||||
next : MyList A;
|
||||
};
|
||||
|
||||
go {A B} (f : A -> B) : MyList A -> MyList B
|
||||
| myNil := myNil
|
||||
| myCons@{elem; next} :=
|
||||
myCons@{
|
||||
elem := f elem;
|
||||
next := go f next;
|
||||
};
|
Loading…
Reference in New Issue
Block a user