mirror of
https://github.com/anoma/juvix.git
synced 2024-12-01 00:04:58 +03:00
Implement lambda lifting without letrec (#1494)
Co-authored-by: Paul Cadman <git@paulcadman.dev> Co-authored-by: Łukasz Czajka <62751+lukaszcz@users.noreply.github.com>
This commit is contained in:
parent
2eb51ce1c3
commit
3262906772
@ -30,7 +30,7 @@ emptyInfoTable =
|
||||
data IdentKind = IdentSym Symbol | IdentTag Tag
|
||||
|
||||
data IdentifierInfo = IdentifierInfo
|
||||
{ _identifierName :: Name,
|
||||
{ _identifierName :: Maybe Name,
|
||||
_identifierSymbol :: Symbol,
|
||||
_identifierType :: Type,
|
||||
-- _identifierArgsNum will be used often enough to justify avoiding recomputation
|
||||
|
@ -60,7 +60,8 @@ runInfoTableBuilder tab =
|
||||
return (UserTag (s ^. stateNextUserTag - 1))
|
||||
RegisterIdent ii -> do
|
||||
modify' (over stateInfoTable (over infoIdentifiers (HashMap.insert (ii ^. identifierSymbol) ii)))
|
||||
modify' (over stateInfoTable (over identMap (HashMap.insert (ii ^. (identifierName . nameText)) (IdentSym (ii ^. identifierSymbol)))))
|
||||
whenJust (ii ^? identifierName . _Just . nameText) $ \name ->
|
||||
modify' (over stateInfoTable (over identMap (HashMap.insert name (IdentSym (ii ^. identifierSymbol)))))
|
||||
RegisterConstructor ci -> do
|
||||
modify' (over stateInfoTable (over infoConstructors (HashMap.insert (ci ^. constructorTag) ci)))
|
||||
modify' (over stateInfoTable (over identMap (HashMap.insert (ci ^. (constructorName . nameText)) (IdentTag (ci ^. constructorTag)))))
|
||||
|
@ -9,6 +9,7 @@ module Juvix.Compiler.Core.Extra
|
||||
)
|
||||
where
|
||||
|
||||
import Data.HashMap.Strict qualified as HashMap
|
||||
import Data.HashSet qualified as HashSet
|
||||
import Juvix.Compiler.Core.Extra.Base
|
||||
import Juvix.Compiler.Core.Extra.Equality
|
||||
@ -29,7 +30,7 @@ freeVars f = ufoldNA reassemble go
|
||||
where
|
||||
go k = \case
|
||||
NVar var@(Var _ idx)
|
||||
| idx >= k -> NVar <$> f var
|
||||
| idx >= k -> NVar <$> f (shiftVar (-k) var)
|
||||
n -> pure n
|
||||
|
||||
getIdents :: Node -> HashSet Ident
|
||||
@ -49,14 +50,48 @@ countFreeVarOccurrences idx = gatherN go 0
|
||||
NVar (Var _ idx') | idx' == idx + k -> acc + 1
|
||||
_ -> acc
|
||||
|
||||
shiftVar :: Index -> Var -> Var
|
||||
shiftVar m = over varIndex (+ m)
|
||||
|
||||
-- | increase all free variable indices by a given value
|
||||
shift :: Index -> Node -> Node
|
||||
shift 0 = id
|
||||
shift m = umapN go
|
||||
where
|
||||
go k n = case n of
|
||||
NVar (Var i idx) | idx >= k -> mkVar i (idx + m)
|
||||
_ -> n
|
||||
go k = \case
|
||||
NVar v
|
||||
| v ^. varIndex >= k -> NVar (shiftVar m v)
|
||||
n -> n
|
||||
|
||||
-- | Prism for NLam
|
||||
_NLam :: SimpleFold Node Lambda
|
||||
_NLam f = \case
|
||||
NLam l -> NLam <$> f l
|
||||
n -> pure n
|
||||
|
||||
-- | Fold over all of the transitive descendants of a Node, including itself.
|
||||
cosmos :: SimpleFold Node Node
|
||||
cosmos f = ufoldA reassemble f
|
||||
|
||||
-- | The list should not contain repeated indices. The 'Info' corresponds to the
|
||||
-- binder of the variable.
|
||||
captureFreeVars :: [(Index, Info)] -> Node -> Node
|
||||
captureFreeVars fv
|
||||
| n == 0 = id
|
||||
| otherwise = mkLambdas infos . mapFreeVars
|
||||
where
|
||||
(indices, infos) = unzip fv
|
||||
n = length fv
|
||||
s :: HashMap Index Index
|
||||
s = HashMap.fromList (zip indices [0 ..])
|
||||
mapFreeVars :: Node -> Node
|
||||
mapFreeVars = dmapN go
|
||||
where
|
||||
go :: Index -> Node -> Node
|
||||
go k = \case
|
||||
NVar (Var i u)
|
||||
| Just v <- s ^. at (u - k) -> NVar (Var i (v + k))
|
||||
m -> m
|
||||
|
||||
-- | substitute a term t for the free variable with de Bruijn index 0, avoiding
|
||||
-- variable capture; shifts all free variabes with de Bruijn index > 0 by -1 (as
|
||||
|
@ -1,10 +1,7 @@
|
||||
-- | This file defines Infos stored in JuvixCore Nodes. The Info data structure
|
||||
-- maps an info type to an info of that type.
|
||||
module Juvix.Compiler.Core.Info where
|
||||
|
||||
{-
|
||||
This file defines Infos stored in JuvixCore Nodes. The Info data structure
|
||||
maps an info type to an info of that type.
|
||||
-}
|
||||
|
||||
import Data.Dynamic
|
||||
import Data.HashMap.Strict qualified as HashMap
|
||||
import Juvix.Prelude
|
||||
@ -14,6 +11,7 @@ class Typeable a => IsInfo a
|
||||
newtype Info = Info
|
||||
{ _infoMap :: HashMap TypeRep Dynamic
|
||||
}
|
||||
deriving newtype (Semigroup, Monoid)
|
||||
|
||||
type Key = Proxy
|
||||
|
||||
|
@ -267,7 +267,7 @@ instance PrettyCode InfoTable where
|
||||
where
|
||||
ppDef :: Symbol -> Node -> Sem r (Doc Ann)
|
||||
ppDef s n = do
|
||||
sym' <- maybe (return (pretty s)) ppCode (tbl ^? infoIdentifiers . at s . _Just . identifierName)
|
||||
sym' <- maybe (return (pretty s)) ppCode (tbl ^? infoIdentifiers . at s . _Just . identifierName . _Just)
|
||||
body' <- ppCode n
|
||||
return (kwDef <+> sym' <+> kwAssign <+> body')
|
||||
|
||||
|
@ -4,10 +4,68 @@ module Juvix.Compiler.Core.Transformation.LambdaLifting
|
||||
)
|
||||
where
|
||||
|
||||
import Juvix.Compiler.Core.Data.BinderList (BinderList)
|
||||
import Juvix.Compiler.Core.Data.BinderList qualified as BL
|
||||
import Juvix.Compiler.Core.Data.InfoTableBuilder
|
||||
import Juvix.Compiler.Core.Extra
|
||||
import Juvix.Compiler.Core.Info qualified as Info
|
||||
import Juvix.Compiler.Core.Info.NameInfo
|
||||
import Juvix.Compiler.Core.Info.TypeInfo
|
||||
import Juvix.Compiler.Core.Pretty
|
||||
import Juvix.Compiler.Core.Transformation.Base
|
||||
|
||||
lambdaLiftNode :: Node -> Sem r Node
|
||||
lambdaLiftNode = return
|
||||
lambdaLiftNode :: forall r. Member InfoTableBuilder r => BinderList Info -> Node -> Sem r Node
|
||||
lambdaLiftNode aboveBl top =
|
||||
mkLambdas topArgs <$> dmapLRM' (topArgsBinderList <> aboveBl, go) body
|
||||
where
|
||||
(topArgs, body) = unfoldLambdas top
|
||||
topArgsBinderList :: BinderList Info
|
||||
topArgsBinderList = BL.fromList topArgs
|
||||
typeFromArgs :: [ArgumentInfo] -> Type
|
||||
typeFromArgs = \case
|
||||
[] -> mkDynamic' -- change this when we have type info about the body
|
||||
(a : as) -> mkPi' argTy (typeFromArgs as)
|
||||
where
|
||||
argTy = fromMaybe mkDynamic' (a ^. argumentType)
|
||||
-- extracts the argument info from the binder
|
||||
argInfo :: Info -> ArgumentInfo
|
||||
argInfo i =
|
||||
ArgumentInfo
|
||||
{ _argumentName = (^. infoName) <$> Info.lookup (Proxy @NameInfo) i,
|
||||
_argumentType = (^. infoType) <$> Info.lookup (Proxy @TypeInfo) i,
|
||||
_argumentIsImplicit = False
|
||||
}
|
||||
go :: BinderList Info -> Node -> Sem r Recur
|
||||
go bl = \case
|
||||
l@NLam {} -> do
|
||||
l' <- lambdaLiftNode bl l
|
||||
let freevars = toList (getFreeVars l')
|
||||
freevarsAssocs :: [(Index, Info)]
|
||||
freevarsAssocs = [(i, BL.lookup i bl) | i <- map (^. varIndex) freevars]
|
||||
fBody' = captureFreeVars freevarsAssocs l'
|
||||
argsInfo :: [ArgumentInfo]
|
||||
argsInfo = map (argInfo . snd) freevarsAssocs
|
||||
f <- freshSymbol
|
||||
registerIdent
|
||||
IdentifierInfo
|
||||
{ _identifierSymbol = f,
|
||||
_identifierName = Nothing,
|
||||
_identifierType = typeFromArgs argsInfo,
|
||||
_identifierArgsNum = length freevars,
|
||||
_identifierArgsInfo = argsInfo,
|
||||
_identifierIsExported = False
|
||||
}
|
||||
registerIdentNode f fBody'
|
||||
let fApp = mkApps' (mkIdent mempty f) (map NVar freevars)
|
||||
return (End fApp)
|
||||
m -> return (Recur m)
|
||||
|
||||
lambdaLifting :: InfoTable -> InfoTable
|
||||
lambdaLifting = run . mapT' lambdaLiftNode
|
||||
lambdaLifting = run . mapT' (lambdaLiftNode mempty)
|
||||
|
||||
-- | True if lambdas are only found at the top level
|
||||
isLifted :: Node -> Bool
|
||||
isLifted = not . hasNestedLambdas
|
||||
where
|
||||
hasNestedLambdas :: Node -> Bool
|
||||
hasNestedLambdas = has (cosmos . _NLam) . snd . unfoldLambdas'
|
||||
|
@ -141,7 +141,7 @@ statementDef = do
|
||||
name <- lift $ freshName KNameFunction txt i
|
||||
let info =
|
||||
IdentifierInfo
|
||||
{ _identifierName = name,
|
||||
{ _identifierName = Just name,
|
||||
_identifierSymbol = sym,
|
||||
_identifierType = mkDynamic',
|
||||
_identifierArgsNum = 0,
|
||||
|
@ -331,6 +331,16 @@ allElements = [minBound .. maxBound]
|
||||
readerState :: forall a r x. (Member (State a) r) => Sem (Reader a ': r) x -> Sem r x
|
||||
readerState m = get >>= (`runReader` m)
|
||||
|
||||
infixr 3 .&&.
|
||||
|
||||
(.&&.) :: (a -> Bool) -> (a -> Bool) -> a -> Bool
|
||||
(a .&&. b) c = a c && b c
|
||||
|
||||
infixr 2 .||.
|
||||
|
||||
(.||.) :: (a -> Bool) -> (a -> Bool) -> a -> Bool
|
||||
(a .||. b) c = a c || b c
|
||||
|
||||
class CanonicalProjection a b where
|
||||
project :: a -> b
|
||||
|
||||
|
@ -10,7 +10,7 @@ import Prettyprinter.Render.Text qualified as Text
|
||||
data Test = Test
|
||||
{ _testName :: String,
|
||||
_testCoreFile :: FilePath,
|
||||
_testExpectedFile :: FilePath,
|
||||
_testAssertion :: InfoTable -> Assertion,
|
||||
_testTransformations :: [TransformationId]
|
||||
}
|
||||
|
||||
@ -28,12 +28,22 @@ toTestDescr t@Test {..} =
|
||||
_testAssertion = Single (coreTransAssertion t)
|
||||
}
|
||||
|
||||
assertExpectedOutput :: FilePath -> InfoTable -> Assertion
|
||||
assertExpectedOutput testExpectedFile r = do
|
||||
expected <- readFile testExpectedFile
|
||||
let actualOutput = Text.renderStrict (toTextStream (ppOut opts r))
|
||||
assertEqDiff ("Check: output = " <> testExpectedFile) actualOutput expected
|
||||
where
|
||||
opts :: Options
|
||||
opts =
|
||||
defaultOptions
|
||||
{ _optShowDeBruijnIndices = True
|
||||
}
|
||||
|
||||
coreTransAssertion :: Test -> Assertion
|
||||
coreTransAssertion Test {..} = do
|
||||
r <- applyTransformations [LambdaLifting] <$> parseFile _testCoreFile
|
||||
expected <- readFile _testExpectedFile
|
||||
let actualOutput = Text.renderStrict (toTextStream (ppOutDefault r))
|
||||
assertEqDiff ("Check: EVAL output = " <> _testExpectedFile) actualOutput expected
|
||||
_testAssertion r
|
||||
|
||||
parseFile :: FilePath -> IO InfoTable
|
||||
parseFile f = fst <$> fromRightIO show (runParser "" f emptyInfoTable <$> readFile f)
|
||||
|
@ -1,9 +1,35 @@
|
||||
module Core.Transformation.Lifting (allTests) where
|
||||
|
||||
import Base
|
||||
import Core.Transformation.Base
|
||||
import Juvix.Compiler.Core.Transformation
|
||||
|
||||
allTests :: TestTree
|
||||
allTests = testGroup "Lambda lifting" tests
|
||||
|
||||
pipe :: [TransformationId]
|
||||
pipe = [LambdaLifting]
|
||||
|
||||
dir :: FilePath
|
||||
dir = "lambda-lifting"
|
||||
|
||||
liftTest :: String -> FilePath -> FilePath -> TestTree
|
||||
liftTest _testName _testCoreFile _testExpectedFile =
|
||||
fromTest
|
||||
Test
|
||||
{ _testTransformations = pipe,
|
||||
_testCoreFile = dir </> _testCoreFile,
|
||||
_testName,
|
||||
_testAssertion = assertExpectedOutput expectedFile
|
||||
}
|
||||
where
|
||||
expectedFile = dir </> _testExpectedFile
|
||||
|
||||
tests :: [TestTree]
|
||||
tests = []
|
||||
tests =
|
||||
[ liftTest
|
||||
("Lambda lifting without let rec " <> i)
|
||||
("test" <> i <> ".jvc")
|
||||
("test" <> i <> ".out")
|
||||
| i <- map show [1 :: Int .. 3]
|
||||
]
|
||||
|
1
tests/Core/positive/lambda-lifting/test1.jvc
Normal file
1
tests/Core/positive/lambda-lifting/test1.jvc
Normal file
@ -0,0 +1 @@
|
||||
def t1 := \g \f f (\x \y \z g x);
|
3
tests/Core/positive/lambda-lifting/test1.out
Normal file
3
tests/Core/positive/lambda-lifting/test1.out
Normal file
@ -0,0 +1,3 @@
|
||||
-- IdentContext
|
||||
def 1 ≔ λg λx λy λz g$3 x$2
|
||||
def t1 ≔ λg λf f$0 (!1 g$1)
|
1
tests/Core/positive/lambda-lifting/test2.jvc
Normal file
1
tests/Core/positive/lambda-lifting/test2.jvc
Normal file
@ -0,0 +1 @@
|
||||
def t2 := \r \s r (\x \y s y (\z z y x (\w w x) (\e y e x y)));
|
6
tests/Core/positive/lambda-lifting/test2.out
Normal file
6
tests/Core/positive/lambda-lifting/test2.out
Normal file
@ -0,0 +1,6 @@
|
||||
-- IdentContext
|
||||
def 1 ≔ λx λw w$0 x$1
|
||||
def 2 ≔ λy λx λe y$2 e$0 x$1 y$2
|
||||
def 3 ≔ λy λx λz z$0 y$2 x$1 (!1 x$1) (!2 x$1 y$2)
|
||||
def 4 ≔ λs λx λy s$2 y$0 (!3 x$1 y$0)
|
||||
def t2 ≔ λr λs r$1 (!4 s$0)
|
3
tests/Core/positive/lambda-lifting/test3.jvc
Normal file
3
tests/Core/positive/lambda-lifting/test3.jvc
Normal file
@ -0,0 +1,3 @@
|
||||
def const := \x \y x;
|
||||
def id := \x x;
|
||||
def t3 := \r \s const (\x x) (id (\x r));
|
6
tests/Core/positive/lambda-lifting/test3.out
Normal file
6
tests/Core/positive/lambda-lifting/test3.out
Normal file
@ -0,0 +1,6 @@
|
||||
-- IdentContext
|
||||
def 3 ≔ λx x$0
|
||||
def 4 ≔ λr λx r$1
|
||||
def const ≔ λx λy x$1
|
||||
def id ≔ λx x$0
|
||||
def t3 ≔ λr λs const !3 (id (!4 r$1))
|
Loading…
Reference in New Issue
Block a user