1
1
mirror of https://github.com/anoma/juvix.git synced 2024-12-02 23:43:01 +03:00

Eta expansion at the top of each core function definition (#1481) (#1571)

This commit is contained in:
janmasrovira 2022-11-14 16:03:28 +01:00 committed by GitHub
parent d1ec8926c4
commit 169155690b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 108 additions and 11 deletions

View File

@ -2,7 +2,7 @@ module Commands.Dev.Core.Read.Options where
import Commands.Dev.Core.Eval.Options qualified as Eval
import CommonOptions
import Evaluator qualified as Evaluator
import Evaluator qualified
import Juvix.Compiler.Core.Data.TransformationId.Parser
import Juvix.Compiler.Core.Pretty.Options qualified as Core
@ -58,7 +58,7 @@ parseCoreReadOptions = do
<> short 't'
<> value mempty
<> metavar "[Transform]"
<> help "comma sep list of transformations. Available: lifting"
<> help "comma sep list of transformations. Available: lifting, top-eta-expand, identity"
)
_coreReadInputFile <- parseInputJuvixCoreFile
pure CoreReadOptions {..}

View File

@ -13,7 +13,7 @@ data InfoTableBuilder m a where
RegisterInductive :: InductiveInfo -> InfoTableBuilder m ()
RegisterIdentNode :: Symbol -> Node -> InfoTableBuilder m ()
RegisterMain :: Symbol -> InfoTableBuilder m ()
SetIdentArgsInfo :: Symbol -> [ArgumentInfo] -> InfoTableBuilder m ()
OverIdentArgsInfo :: Symbol -> ([ArgumentInfo] -> [ArgumentInfo]) -> InfoTableBuilder m ()
GetIdent :: Text -> InfoTableBuilder m (Maybe IdentKind)
GetInfoTable :: InfoTableBuilder m InfoTable
@ -31,12 +31,15 @@ checkSymbolDefined sym = do
tab <- getInfoTable
return $ HashMap.member sym (tab ^. identContext)
runInfoTableBuilder :: MkIdentIndex -> InfoTable -> Sem (InfoTableBuilder ': r) a -> Sem r (InfoTable, a)
setIdentArgsInfo :: Member InfoTableBuilder r => Symbol -> [ArgumentInfo] -> Sem r ()
setIdentArgsInfo sym = overIdentArgsInfo sym . const
runInfoTableBuilder :: forall r a. MkIdentIndex -> InfoTable -> Sem (InfoTableBuilder ': r) a -> Sem r (InfoTable, a)
runInfoTableBuilder mkIdentIndex tab =
runState tab
. reinterpret interp
where
interp :: InfoTableBuilder m a -> Sem (State InfoTable : r) a
interp :: InfoTableBuilder m b -> Sem (State InfoTable : r) b
interp = \case
FreshSymbol -> do
s <- get
@ -60,7 +63,8 @@ runInfoTableBuilder mkIdentIndex tab =
modify' (over identContext (HashMap.insert sym node))
RegisterMain sym -> do
modify' (set infoMain (Just sym))
SetIdentArgsInfo sym argsInfo -> do
OverIdentArgsInfo sym f -> do
argsInfo <- f <$> gets (^. infoIdentifiers . at sym . _Just . identifierArgsInfo)
modify' (set (infoIdentifiers . at sym . _Just . identifierArgsInfo) argsInfo)
modify' (set (infoIdentifiers . at sym . _Just . identifierArgsNum) (length argsInfo))
modify' (over infoIdentifiers (HashMap.adjust (over identifierType (expandType (map (^. argumentType) argsInfo))) sym))

View File

@ -4,5 +4,6 @@ import Juvix.Prelude
data TransformationId
= LambdaLifting
| TopEtaExpand
| Identity
deriving stock (Data)

View File

@ -30,3 +30,4 @@ transformation :: MonadParsec e Text m => m TransformationId
transformation =
symbol "lifting" $> LambdaLifting
<|> symbol "identity" $> Identity
<|> symbol "top-eta-expand" $> TopEtaExpand

View File

@ -3,6 +3,7 @@ module Juvix.Compiler.Core.Transformation
module Juvix.Compiler.Core.Transformation,
module Juvix.Compiler.Core.Transformation.Eta,
module Juvix.Compiler.Core.Transformation.LambdaLifting,
module Juvix.Compiler.Core.Transformation.TopEtaExpand,
module Juvix.Compiler.Core.Data.TransformationId,
)
where
@ -12,6 +13,7 @@ import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Eta
import Juvix.Compiler.Core.Transformation.Identity
import Juvix.Compiler.Core.Transformation.LambdaLifting
import Juvix.Compiler.Core.Transformation.TopEtaExpand
applyTransformations :: [TransformationId] -> InfoTable -> InfoTable
applyTransformations ts tbl = foldl' (flip appTrans) tbl ts
@ -20,3 +22,4 @@ applyTransformations ts tbl = foldl' (flip appTrans) tbl ts
appTrans = \case
LambdaLifting -> lambdaLifting
Identity -> identity
TopEtaExpand -> topEtaExpand

View File

@ -15,12 +15,12 @@ type Transformation = InfoTable -> InfoTable
mapT :: (Symbol -> Node -> Node) -> InfoTable -> InfoTable
mapT f tab = tab {_identContext = HashMap.mapWithKey f (tab ^. identContext)}
mapT' :: (Node -> Sem (InfoTableBuilder ': r) Node) -> InfoTable -> Sem r InfoTable
mapT' :: (Symbol -> Node -> Sem (InfoTableBuilder ': r) Node) -> InfoTable -> Sem r InfoTable
mapT' f tab =
fmap fst $
runInfoTableBuilder (^. nameText) tab $
mapM_
(\(k, v) -> f v >>= registerIdentNode k)
(\(k, v) -> f k v >>= registerIdentNode k)
(HashMap.toList (tab ^. identContext))
walkT :: Applicative f => (Symbol -> Node -> f ()) -> InfoTable -> f ()

View File

@ -8,4 +8,4 @@ import Juvix.Compiler.Core.Pretty
import Juvix.Compiler.Core.Transformation.Base
identity :: InfoTable -> InfoTable
identity = run . mapT' return
identity = run . mapT' (const return)

View File

@ -139,7 +139,7 @@ lambdaLiftNode aboveBl top =
return (Recur res)
lambdaLifting :: InfoTable -> InfoTable
lambdaLifting = run . mapT' (lambdaLiftNode mempty)
lambdaLifting = run . mapT' (const (lambdaLiftNode mempty))
-- | True if lambdas are only found at the top level
nodeIsLifted :: Node -> Bool

View File

@ -0,0 +1,42 @@
module Juvix.Compiler.Core.Transformation.TopEtaExpand where
import Juvix.Compiler.Core.Data.InfoTableBuilder
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Transformation.Base
topEtaExpand :: InfoTable -> InfoTable
topEtaExpand info = run (mapT' go info)
where
go :: Symbol -> Node -> Sem '[InfoTableBuilder] Node
go sym body = case info ^. infoIdentifiers . at sym of
Nothing -> return body
Just idenInfo ->
let args :: [PiLhs]
args = fst (unfoldPi (idenInfo ^. identifierType))
in skipLambdas args body
where
skipLambdas :: [PiLhs] -> Node -> Sem '[InfoTableBuilder] Node
skipLambdas args node = case args of
[] -> return node
(_ : as) -> case node of
NLam l -> NLam <$> traverseOf lambdaBody (skipLambdas as) l
_ -> do
let newArgsInfo :: [ArgumentInfo]
newArgsInfo = map toArgumentInfo as
overIdentArgsInfo sym (++ newArgsInfo)
return (expand node (reverse args))
toArgumentInfo :: PiLhs -> ArgumentInfo
toArgumentInfo pi =
ArgumentInfo
{ _argumentName = pi ^. piLhsBinder . binderName,
_argumentType = pi ^. piLhsBinder . binderType,
_argumentIsImplicit = Explicit
}
expand :: Node -> [PiLhs] -> Node
expand n lhs = reLambdas (map lambdaFromPi lhs) body'
where
len = length lhs
body' = mkApps' (shift len n) [mkVar' v | v <- reverse [0 .. len - 1]]
-- We keep the name and type. We drop the other info
lambdaFromPi :: PiLhs -> LambdaLhs
lambdaFromPi pi = LambdaLhs mempty (pi ^. piLhsBinder)

View File

@ -244,5 +244,10 @@ tests =
"Dependent lambda-abstractions"
"."
"test043.jvc"
"out/test043.out"
"out/test043.out",
PosTest
"Eta-expansion"
"."
"test044.jvc"
"out/test044.out"
]

View File

@ -3,11 +3,13 @@ module Core.Transformation where
import Base
import Core.Transformation.Identity qualified as Identity
import Core.Transformation.Lifting qualified as Lifting
import Core.Transformation.TopEtaExpand qualified as TopEtaExpand
allTests :: TestTree
allTests =
testGroup
"JuvixCore transformations"
[ Identity.allTests,
TopEtaExpand.allTests,
Lifting.allTests
]

View File

@ -0,0 +1,21 @@
module Core.Transformation.TopEtaExpand (allTests) where
import Base
import Core.Eval.Positive qualified as Eval
import Core.Transformation.Base
import Juvix.Compiler.Core.Transformation
allTests :: TestTree
allTests = testGroup "Top eta expand" (map liftTest Eval.tests)
pipe :: [TransformationId]
pipe = [TopEtaExpand]
liftTest :: Eval.PosTest -> TestTree
liftTest _testEval =
fromTest
Test
{ _testTransformations = pipe,
_testAssertion = const (return ()),
_testEval
}

View File

@ -0,0 +1 @@
18

View File

@ -0,0 +1,17 @@
-- eta-expansion
def compose : (int -> int) -> (int -> int) -> int -> int := \f \g \x f (g x);
def expand : any -> int -> any := \f \x f;
def f : int -> int := (+) 1;
def g : int -> int -> int := \z compose f (\x x - z);
def h : int -> int := compose f (g 3);
def j : int -> int -> int := g;
def k : int -> int -> int -> int := expand j;
h 13 + j 2 3 + k 9 4 7