1
1
mirror of https://github.com/anoma/juvix.git synced 2024-11-30 14:13:27 +03:00

Preserve the target type in letrec lifting (#1945)

- Closes #1887
This commit is contained in:
janmasrovira 2023-03-30 18:02:37 +02:00 committed by GitHub
parent e1e4216504
commit 9e9a884fdb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 120 additions and 62 deletions

View File

@ -96,11 +96,14 @@ _NLam f = \case
cosmos :: SimpleFold Node Node
cosmos f = ufoldA reassemble f
-- | The list should not contain repeated indices.
-- if fv = x1, x2, .., xn
-- the result is of the form λx1 λx2 .. λ xn b
captureFreeVars :: [(Index, Binder)] -> Node -> Node
captureFreeVars freevars = goBinders freevars . mapFreeVars
-- | The free vars are given in the context of the node.
captureFreeVarsType :: [(Index, Binder)] -> (Node, Type) -> (Node, Type)
captureFreeVarsType freevars (n, ty) =
let bodyTy = mapFreeVars ty
body' = mapFreeVars n
in ( mkLambdasB captureBinders' body',
mkPis captureBinders' bodyTy
)
where
mapFreeVars :: Node -> Node
mapFreeVars = dmapN go
@ -112,25 +115,33 @@ captureFreeVars freevars = goBinders freevars . mapFreeVars
NVar (Var i u)
| Just v <- s ^. at (u - k) -> NVar (Var i (v + k))
m -> m
goBinders :: [(Index, Binder)] -> Node -> Node
goBinders fv = case unsnoc fv of
Nothing -> id
Just (fvs, (idx, bin)) -> goBinders fvs . mkLambdaB (mapBinder idx bin)
captureBinders' :: [Binder]
captureBinders' = goBinders freevars []
where
indices = map fst fv
mapBinder :: Index -> Binder -> Binder
mapBinder binderIndex = over binderType (dmapN go)
goBinders :: [(Index, Binder)] -> [Binder] -> [Binder]
goBinders fv acc = case unsnoc fv of
Nothing -> acc
Just (fvs, (idx, bin)) -> goBinders fvs (mapBinder idx bin : acc)
where
go :: Index -> Node -> Node
go k = \case
NVar u
| u ^. varIndex >= k ->
let uCtx = u ^. varIndex - k + binderIndex + 1
err = error ("impossible: could not find " <> show uCtx <> " in " <> show indices)
u' = length indices - 2 - fromMaybe err (elemIndex uCtx indices) + k
in NVar (set varIndex u' u)
m -> m
indices = map fst fv
mapBinder :: Index -> Binder -> Binder
mapBinder binderIndex = over binderType (dmapN go)
where
go :: Index -> Node -> Node
go k = \case
NVar u
| u ^. varIndex >= k ->
let uCtx = u ^. varIndex - k + binderIndex + 1
err = error ("impossible: could not find " <> show uCtx <> " in " <> show indices)
u' = length indices - 2 - fromMaybe err (elemIndex uCtx indices) + k
in NVar (set varIndex u' u)
m -> m
-- | The list should not contain repeated indices.
-- if fv = x1, x2, .., xn
-- the result is of the form λx1 λx2 .. λ xn b
captureFreeVars :: [(Index, Binder)] -> Node -> Node
captureFreeVars freevars n = fst (captureFreeVarsType freevars (n, mkDynamic'))
-- | Captures all free variables of a node. It also returns the list of captured
-- variables in left-to-right order: if snd is of the form λxλy... then fst is
@ -140,6 +151,12 @@ captureFreeVarsCtx bl n =
let assocs = freeVarsCtx bl n
in (assocs, captureFreeVars (map (first (^. varIndex)) assocs) n)
captureFreeVarsCtxType :: BinderList Binder -> (Node, Type) -> ([(Var, Binder)], (Node, Type))
captureFreeVarsCtxType bl (n, ty) =
let assocs = freeVarsCtx bl n
assocsi = map (first (^. varIndex)) assocs
in (assocs, captureFreeVarsType assocsi (n, ty))
freeVarsCtxMany' :: BinderList Binder -> [Node] -> [Var]
freeVarsCtxMany' bl = map fst . freeVarsCtxMany bl

View File

@ -10,6 +10,7 @@ import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info.NameInfo
import Juvix.Compiler.Core.Pretty
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.ComputeTypeInfo (computeNodeType)
lambdaLiftBinder :: Members '[Reader OnlyLetRec, InfoTableBuilder] r => BinderList Binder -> Binder -> Sem r Binder
lambdaLiftBinder bl = traverseOf binderType (lambdaLiftNode bl)
@ -22,10 +23,8 @@ lambdaLiftNode aboveBl top =
(topArgs, body) = unfoldLambdas top
in goTop aboveBl body topArgs
where
typeFromArgs :: [ArgumentInfo] -> Type
typeFromArgs = \case
[] -> mkDynamic' -- change this when we have type info about the body
(a : as) -> mkPi mempty (binderFromArgumentInfo a) (typeFromArgs as)
nodeType :: Node -> Sem r Type
nodeType n = flip computeNodeType n <$> getInfoTable
goTop :: BinderList Binder -> Node -> [LambdaLhs] -> Sem r Node
goTop bl body = \case
@ -58,13 +57,14 @@ lambdaLiftNode aboveBl top =
argsInfo = map (argumentInfoFromBinder . (^. lambdaLhsBinder)) (fst (unfoldLambdas fBody'))
f <- freshSymbol
let name = uniqueName "lambda" f
ty <- nodeType fBody'
registerIdent
name
IdentifierInfo
{ _identifierSymbol = f,
_identifierName = name,
_identifierLocation = Nothing,
_identifierType = typeFromArgs argsInfo,
_identifierType = ty,
_identifierArgsNum = length argsInfo,
_identifierArgsInfo = argsInfo,
_identifierIsExported = False,
@ -78,10 +78,13 @@ lambdaLiftNode aboveBl top =
goLetRec letr = do
let defs :: [Node]
defs = letr ^.. letRecValues . each . letItemValue
defsTypes :: [Type]
defsTypes = letr ^.. letRecValues . each . letItemBinder . binderType
ndefs :: Int
ndefs = length defs
binders :: [Binder]
binders = letr ^.. letRecValues . each . letItemBinder
letRecBinders' :: [Binder] <- mapM (lambdaLiftBinder bl) binders
topSyms :: [Symbol] <- forM defs (const freshSymbol)
let bl' :: BinderList Binder
@ -98,7 +101,7 @@ lambdaLiftNode aboveBl top =
helper :: Var -> Maybe (Var, Binder)
helper v
| v ^. varIndex < ndefs = Nothing
| otherwise = Just (set varIndex idx' v, BL.lookup idx' bl)
| otherwise = Just (shiftVar (-ndefs) v, BL.lookup idx' bl)
where
idx' = v ^. varIndex - ndefs
@ -120,7 +123,10 @@ lambdaLiftNode aboveBl top =
declareTopSyms =
sequence_
[ do
let topBody = captureFreeVars (map (first (^. varIndex)) recItemsFreeVars) b
let (topBody, topTy) =
captureFreeVarsType
(map (first (^. varIndex)) recItemsFreeVars)
(b, bty)
argsInfo :: [ArgumentInfo]
argsInfo =
map (argumentInfoFromBinder . (^. lambdaLhsBinder)) (fst (unfoldLambdas topBody))
@ -131,13 +137,19 @@ lambdaLiftNode aboveBl top =
{ _identifierSymbol = sym,
_identifierName = name,
_identifierLocation = itemBinder ^. binderLocation,
_identifierType = typeFromArgs argsInfo,
_identifierType = topTy,
_identifierArgsNum = length argsInfo,
_identifierArgsInfo = argsInfo,
_identifierIsExported = False,
_identifierBuiltin = Nothing
}
| ((sym, name), (itemBinder, b)) <- zipExact topSymsWithName (zipExact letRecBinders' liftedDefs)
| ((sym, name), (itemBinder, (b, bty))) <-
zipExact
topSymsWithName
( zipExact
letRecBinders'
(zipExact liftedDefs defsTypes)
)
]
declareTopSyms

View File

@ -1,14 +1,13 @@
module Juvix.Compiler.Core.Translation.Stripped.FromCore (fromCore) where
import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Core.Data.InfoTable
import Juvix.Compiler.Core
import Juvix.Compiler.Core.Data.Stripped.InfoTable qualified as Stripped
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Extra.Stripped.Base qualified as Stripped
import Juvix.Compiler.Core.Info.LocationInfo
import Juvix.Compiler.Core.Info.NameInfo
import Juvix.Compiler.Core.Language
import Juvix.Compiler.Core.Language.Stripped qualified as Stripped
import Juvix.Compiler.Core.Pretty
fromCore :: InfoTable -> Stripped.InfoTable
fromCore tab =
@ -87,7 +86,15 @@ translateFunction :: Int -> Node -> Stripped.Node
translateFunction argsNum node =
let (k, body) = unfoldLambdas' node
in if
| k /= argsNum -> error "wrong number of arguments"
| k /= argsNum ->
error
( "wrong number of arguments. argsNum = "
<> show argsNum
<> ", unfoldLambdas = "
<> show k
<> "\nNode = "
<> ppTrace node
)
| otherwise -> translateNode body
translateNode :: Node -> Stripped.Node

View File

@ -14,7 +14,7 @@ ignoredTests =
"Fast exponentiation",
"Nested 'case', 'let' and 'if' with variable capture",
"Mutual recursion",
"LetRec",
"LetRec - fib, fact",
"Big numbers"
]

View File

@ -233,7 +233,7 @@ tests =
$(mkRelFile "test039.jvc")
$(mkRelFile "out/test039.out"),
PosTest
"LetRec"
"LetRec - fib, fact"
$(mkRelDir ".")
$(mkRelFile "test040.jvc")
$(mkRelFile "out/test040.out"),

View File

@ -1,34 +1,34 @@
-- letrec
-- letrec - fib, fact
def sum := letrec sum := \x if x = 0 then 0 else x + sum (x - 1) in sum;
def sum : Int -> Int := letrec sum : Int -> Int := \x if x = 0 then 0 else x + sum (x - 1) in sum;
def fact := \x
letrec fact' := \x \acc if x = 0 then acc else fact' (x - 1) (acc * x)
def fact : Int -> Int := \x
letrec fact' : Int -> Int -> Int := \x \acc if x = 0 then acc else fact' (x - 1) (acc * x)
in fact' x 1;
def fib :=
letrec fib' := \n \x \y if n = 0 then x else fib' (n - 1) y (x + y)
def fib : Int -> Int :=
letrec fib' : Int -> Int -> Int -> Int := \n \x \y if n = 0 then x else fib' (n - 1) y (x + y)
in \n fib' n 0 1;
def writeLn := \x write x >> write "\n";
def mutrec :=
let two := 2 in
let one := 1 in
def mutrec : IO :=
let two : Int := 2 in
let one : Int := 1 in
letrec[f g h]
f := \x {
f : Int -> Int := \x {
if x < one then
one
else
g (x - one) + two * x
};
g := \x {
g : Int -> Int := \x {
if x < one then
one
else
x + h (x - one)
};
h := \x letrec z := {
h : Int -> Int := \x letrec z : Int := {
if x < one then
one
else
@ -36,7 +36,7 @@ def mutrec :=
} in z;
in writeLn (f 5) >> writeLn (f 10) >> writeLn (f 100) >> writeLn (g 5) >> writeLn (h 5);
letrec x := 3
letrec x : Int := 3
in
writeLn x >>
writeLn (sum 10000) >>
@ -47,9 +47,9 @@ writeLn (fib 10) >>
writeLn (fib 100) >>
writeLn (fib 1000) >>
mutrec >>
letrec x := 1 in
letrec x' := x + letrec x := 2 in x in
letrec x := x' * x' in
letrec y := x + 2 in
letrec z := x + y in
letrec x : Int := 1 in
letrec x' : Int := x + letrec x : Int := 2 in x in
letrec x : Int := x' * x' in
letrec y : Int := x + 2 in
letrec z : Int := x + y in
writeLn (x + y + z)

View File

@ -1,11 +1,33 @@
-- dependent lambda-abstractions
def fun := λ(A : Type) λ(x : A) let f := λ(h : A → A) h (h x) in f (λ(y : A) x);
def fun :
Π A : Type,
A → A :=
λ(A : Type)
λ(x : A)
let f : (A → A) → A := λ(h : A → A) h (h x) in
f (λ(y : A) x);
def fun' : Π T : Type → Type, Π X : Type, Π A : Type, Any :=
λ(T : Type → Type) λ(_ : Type) λ(A : Type) λ(B : T A) λ(x : B)
let f := λ(g : B → B) g (g x) in
let h := λ(b : B) λ(a : A) a * b - b in
f (λ(y : B) h y x);
def helper : Int → Int → Int :=
λ(a : Int)
λ(b : Int)
a * b - b;
fun Int 2 + fun' (λ(A : Type) A) Bool Int Int 3
def fun' : Π T : Type → Type,
Π unused : Type,
Π C : Type,
Π A : Type,
(T A → A → C)
→ A
→ C :=
λ(T : Type → Type)
λ(unused : Type)
λ(C : Type)
λ(A : Type)
λ(mhelper : T A → A → C)
λ(a' : A)
let f : (A → A) → A := λ(g : A → A) g (g a') in
let h : A → A → C := λ(a1 : A) λ(a2 : A) mhelper a2 a1 in
f (λ(y : A) h y a');
fun Int 2 + fun' (λ(A : Type) A) Bool Int Int helper 3