1
1
mirror of https://github.com/anoma/juvix.git synced 2024-12-01 00:04:58 +03:00

Implement lambda lifting without letrec (#1494)

Co-authored-by: Paul Cadman <git@paulcadman.dev>
Co-authored-by: Łukasz Czajka <62751+lukaszcz@users.noreply.github.com>
This commit is contained in:
janmasrovira 2022-09-12 12:45:40 +02:00 committed by GitHub
parent 2eb51ce1c3
commit 3262906772
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 179 additions and 21 deletions

View File

@ -30,7 +30,7 @@ emptyInfoTable =
data IdentKind = IdentSym Symbol | IdentTag Tag
data IdentifierInfo = IdentifierInfo
{ _identifierName :: Name,
{ _identifierName :: Maybe Name,
_identifierSymbol :: Symbol,
_identifierType :: Type,
-- _identifierArgsNum will be used often enough to justify avoiding recomputation

View File

@ -60,7 +60,8 @@ runInfoTableBuilder tab =
return (UserTag (s ^. stateNextUserTag - 1))
RegisterIdent ii -> do
modify' (over stateInfoTable (over infoIdentifiers (HashMap.insert (ii ^. identifierSymbol) ii)))
modify' (over stateInfoTable (over identMap (HashMap.insert (ii ^. (identifierName . nameText)) (IdentSym (ii ^. identifierSymbol)))))
whenJust (ii ^? identifierName . _Just . nameText) $ \name ->
modify' (over stateInfoTable (over identMap (HashMap.insert name (IdentSym (ii ^. identifierSymbol)))))
RegisterConstructor ci -> do
modify' (over stateInfoTable (over infoConstructors (HashMap.insert (ci ^. constructorTag) ci)))
modify' (over stateInfoTable (over identMap (HashMap.insert (ci ^. (constructorName . nameText)) (IdentTag (ci ^. constructorTag)))))

View File

@ -9,6 +9,7 @@ module Juvix.Compiler.Core.Extra
)
where
import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Extra.Base
import Juvix.Compiler.Core.Extra.Equality
@ -29,7 +30,7 @@ freeVars f = ufoldNA reassemble go
where
go k = \case
NVar var@(Var _ idx)
| idx >= k -> NVar <$> f var
| idx >= k -> NVar <$> f (shiftVar (-k) var)
n -> pure n
getIdents :: Node -> HashSet Ident
@ -49,14 +50,48 @@ countFreeVarOccurrences idx = gatherN go 0
NVar (Var _ idx') | idx' == idx + k -> acc + 1
_ -> acc
shiftVar :: Index -> Var -> Var
shiftVar m = over varIndex (+ m)
-- | increase all free variable indices by a given value
shift :: Index -> Node -> Node
shift 0 = id
shift m = umapN go
where
go k n = case n of
NVar (Var i idx) | idx >= k -> mkVar i (idx + m)
_ -> n
go k = \case
NVar v
| v ^. varIndex >= k -> NVar (shiftVar m v)
n -> n
-- | Prism for NLam
_NLam :: SimpleFold Node Lambda
_NLam f = \case
NLam l -> NLam <$> f l
n -> pure n
-- | Fold over all of the transitive descendants of a Node, including itself.
cosmos :: SimpleFold Node Node
cosmos f = ufoldA reassemble f
-- | The list should not contain repeated indices. The 'Info' corresponds to the
-- binder of the variable.
captureFreeVars :: [(Index, Info)] -> Node -> Node
captureFreeVars fv
| n == 0 = id
| otherwise = mkLambdas infos . mapFreeVars
where
(indices, infos) = unzip fv
n = length fv
s :: HashMap Index Index
s = HashMap.fromList (zip indices [0 ..])
mapFreeVars :: Node -> Node
mapFreeVars = dmapN go
where
go :: Index -> Node -> Node
go k = \case
NVar (Var i u)
| Just v <- s ^. at (u - k) -> NVar (Var i (v + k))
m -> m
-- | substitute a term t for the free variable with de Bruijn index 0, avoiding
-- variable capture; shifts all free variabes with de Bruijn index > 0 by -1 (as

View File

@ -1,10 +1,7 @@
-- | This file defines Infos stored in JuvixCore Nodes. The Info data structure
-- maps an info type to an info of that type.
module Juvix.Compiler.Core.Info where
{-
This file defines Infos stored in JuvixCore Nodes. The Info data structure
maps an info type to an info of that type.
-}
import Data.Dynamic
import Data.HashMap.Strict qualified as HashMap
import Juvix.Prelude
@ -14,6 +11,7 @@ class Typeable a => IsInfo a
newtype Info = Info
{ _infoMap :: HashMap TypeRep Dynamic
}
deriving newtype (Semigroup, Monoid)
type Key = Proxy

View File

@ -267,7 +267,7 @@ instance PrettyCode InfoTable where
where
ppDef :: Symbol -> Node -> Sem r (Doc Ann)
ppDef s n = do
sym' <- maybe (return (pretty s)) ppCode (tbl ^? infoIdentifiers . at s . _Just . identifierName)
sym' <- maybe (return (pretty s)) ppCode (tbl ^? infoIdentifiers . at s . _Just . identifierName . _Just)
body' <- ppCode n
return (kwDef <+> sym' <+> kwAssign <+> body')

View File

@ -4,10 +4,68 @@ module Juvix.Compiler.Core.Transformation.LambdaLifting
)
where
import Juvix.Compiler.Core.Data.BinderList (BinderList)
import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Data.InfoTableBuilder
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info qualified as Info
import Juvix.Compiler.Core.Info.NameInfo
import Juvix.Compiler.Core.Info.TypeInfo
import Juvix.Compiler.Core.Pretty
import Juvix.Compiler.Core.Transformation.Base
lambdaLiftNode :: Node -> Sem r Node
lambdaLiftNode = return
lambdaLiftNode :: forall r. Member InfoTableBuilder r => BinderList Info -> Node -> Sem r Node
lambdaLiftNode aboveBl top =
mkLambdas topArgs <$> dmapLRM' (topArgsBinderList <> aboveBl, go) body
where
(topArgs, body) = unfoldLambdas top
topArgsBinderList :: BinderList Info
topArgsBinderList = BL.fromList topArgs
typeFromArgs :: [ArgumentInfo] -> Type
typeFromArgs = \case
[] -> mkDynamic' -- change this when we have type info about the body
(a : as) -> mkPi' argTy (typeFromArgs as)
where
argTy = fromMaybe mkDynamic' (a ^. argumentType)
-- extracts the argument info from the binder
argInfo :: Info -> ArgumentInfo
argInfo i =
ArgumentInfo
{ _argumentName = (^. infoName) <$> Info.lookup (Proxy @NameInfo) i,
_argumentType = (^. infoType) <$> Info.lookup (Proxy @TypeInfo) i,
_argumentIsImplicit = False
}
go :: BinderList Info -> Node -> Sem r Recur
go bl = \case
l@NLam {} -> do
l' <- lambdaLiftNode bl l
let freevars = toList (getFreeVars l')
freevarsAssocs :: [(Index, Info)]
freevarsAssocs = [(i, BL.lookup i bl) | i <- map (^. varIndex) freevars]
fBody' = captureFreeVars freevarsAssocs l'
argsInfo :: [ArgumentInfo]
argsInfo = map (argInfo . snd) freevarsAssocs
f <- freshSymbol
registerIdent
IdentifierInfo
{ _identifierSymbol = f,
_identifierName = Nothing,
_identifierType = typeFromArgs argsInfo,
_identifierArgsNum = length freevars,
_identifierArgsInfo = argsInfo,
_identifierIsExported = False
}
registerIdentNode f fBody'
let fApp = mkApps' (mkIdent mempty f) (map NVar freevars)
return (End fApp)
m -> return (Recur m)
lambdaLifting :: InfoTable -> InfoTable
lambdaLifting = run . mapT' lambdaLiftNode
lambdaLifting = run . mapT' (lambdaLiftNode mempty)
-- | True if lambdas are only found at the top level
isLifted :: Node -> Bool
isLifted = not . hasNestedLambdas
where
hasNestedLambdas :: Node -> Bool
hasNestedLambdas = has (cosmos . _NLam) . snd . unfoldLambdas'

View File

@ -141,7 +141,7 @@ statementDef = do
name <- lift $ freshName KNameFunction txt i
let info =
IdentifierInfo
{ _identifierName = name,
{ _identifierName = Just name,
_identifierSymbol = sym,
_identifierType = mkDynamic',
_identifierArgsNum = 0,

View File

@ -331,6 +331,16 @@ allElements = [minBound .. maxBound]
readerState :: forall a r x. (Member (State a) r) => Sem (Reader a ': r) x -> Sem r x
readerState m = get >>= (`runReader` m)
infixr 3 .&&.
(.&&.) :: (a -> Bool) -> (a -> Bool) -> a -> Bool
(a .&&. b) c = a c && b c
infixr 2 .||.
(.||.) :: (a -> Bool) -> (a -> Bool) -> a -> Bool
(a .||. b) c = a c || b c
class CanonicalProjection a b where
project :: a -> b

View File

@ -10,7 +10,7 @@ import Prettyprinter.Render.Text qualified as Text
data Test = Test
{ _testName :: String,
_testCoreFile :: FilePath,
_testExpectedFile :: FilePath,
_testAssertion :: InfoTable -> Assertion,
_testTransformations :: [TransformationId]
}
@ -28,12 +28,22 @@ toTestDescr t@Test {..} =
_testAssertion = Single (coreTransAssertion t)
}
assertExpectedOutput :: FilePath -> InfoTable -> Assertion
assertExpectedOutput testExpectedFile r = do
expected <- readFile testExpectedFile
let actualOutput = Text.renderStrict (toTextStream (ppOut opts r))
assertEqDiff ("Check: output = " <> testExpectedFile) actualOutput expected
where
opts :: Options
opts =
defaultOptions
{ _optShowDeBruijnIndices = True
}
coreTransAssertion :: Test -> Assertion
coreTransAssertion Test {..} = do
r <- applyTransformations [LambdaLifting] <$> parseFile _testCoreFile
expected <- readFile _testExpectedFile
let actualOutput = Text.renderStrict (toTextStream (ppOutDefault r))
assertEqDiff ("Check: EVAL output = " <> _testExpectedFile) actualOutput expected
_testAssertion r
parseFile :: FilePath -> IO InfoTable
parseFile f = fst <$> fromRightIO show (runParser "" f emptyInfoTable <$> readFile f)

View File

@ -1,9 +1,35 @@
module Core.Transformation.Lifting (allTests) where
import Base
import Core.Transformation.Base
import Juvix.Compiler.Core.Transformation
allTests :: TestTree
allTests = testGroup "Lambda lifting" tests
pipe :: [TransformationId]
pipe = [LambdaLifting]
dir :: FilePath
dir = "lambda-lifting"
liftTest :: String -> FilePath -> FilePath -> TestTree
liftTest _testName _testCoreFile _testExpectedFile =
fromTest
Test
{ _testTransformations = pipe,
_testCoreFile = dir </> _testCoreFile,
_testName,
_testAssertion = assertExpectedOutput expectedFile
}
where
expectedFile = dir </> _testExpectedFile
tests :: [TestTree]
tests = []
tests =
[ liftTest
("Lambda lifting without let rec " <> i)
("test" <> i <> ".jvc")
("test" <> i <> ".out")
| i <- map show [1 :: Int .. 3]
]

View File

@ -0,0 +1 @@
def t1 := \g \f f (\x \y \z g x);

View File

@ -0,0 +1,3 @@
-- IdentContext
def 1 ≔ λg λx λy λz g$3 x$2
def t1 ≔ λg λf f$0 (!1 g$1)

View File

@ -0,0 +1 @@
def t2 := \r \s r (\x \y s y (\z z y x (\w w x) (\e y e x y)));

View File

@ -0,0 +1,6 @@
-- IdentContext
def 1 ≔ λx λw w$0 x$1
def 2 ≔ λy λx λe y$2 e$0 x$1 y$2
def 3 ≔ λy λx λz z$0 y$2 x$1 (!1 x$1) (!2 x$1 y$2)
def 4 ≔ λs λx λy s$2 y$0 (!3 x$1 y$0)
def t2 ≔ λr λs r$1 (!4 s$0)

View File

@ -0,0 +1,3 @@
def const := \x \y x;
def id := \x x;
def t3 := \r \s const (\x x) (id (\x r));

View File

@ -0,0 +1,6 @@
-- IdentContext
def 3 ≔ λx x$0
def 4 ≔ λr λx r$1
def const ≔ λx λy x$1
def id ≔ λx x$0
def t3 ≔ λr λs const !3 (id (!4 r$1))