Construct ParsedModule directly in Daml Repl (#10701)

This is a relatively large change unfortunately which unfortunately
requires reimplementing parts of the logic of the typechecker & core
compilation. I don’t think it is too bad but we might want to think
over time if we can factor this better.

This fixes #10073 and fixes #10664 by referencing the exact types
instead of going via the renamer.

There are some minor changes around error messages for "module not
found" errors. This is because these are now caught in the
typechecker instead of in our own code. We could keep the errors but
it requires duplicating even more logic and I don’t really see what it
buys us so I think I prefer the approach here.

changelog_begin

- [Daml Repl] Fix a bug where bindings with out of scope types would result in error in following lines.

changelog_end
This commit is contained in:
Moritz Kiefer 2021-08-30 17:28:16 +02:00 committed by GitHub
parent bbdf16aacf
commit 6016633bb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 198 additions and 126 deletions

View File

@ -12,8 +12,10 @@ module DA.Daml.Compiler.Repl
, ReplLogger(..)
) where
import TcRnTypes (tcg_rdr_env)
import BasicTypes (Boxity(..))
import FastString
import TysWiredIn (unitDataCon, unitTyCon)
import BasicTypes (Boxity(..), PromotionFlag(..), Origin(..))
import DynFlags
import Bag (bagToList, unitBag)
import Control.Applicative
import Control.Concurrent.Extra
@ -25,8 +27,13 @@ import qualified Control.Monad.State.Strict as State
import Control.Monad.Trans.Maybe
import DA.Daml.Compiler.Output (printDiagnostics)
import qualified DA.Daml.LF.Ast as LF
import qualified DA.Daml.LF.InferSerializability as Serializability
import qualified DA.Daml.LF.Simplifier as LF
import qualified DA.Daml.LF.TypeChecker as LF
import DA.Daml.LF.Ast.Optics (packageRefs)
import qualified DA.Daml.LF.ReplClient as ReplClient
import DA.Daml.LFConversion (convertModule)
import DA.Daml.LFConversion.UtilLF (buildPackage)
import DA.Daml.LFConversion.UtilGHC
import DA.Daml.Options.Types
import qualified DA.Daml.Preprocessor.Records as Preprocessor
@ -45,22 +52,28 @@ import Data.Semigroup (Last(..))
import qualified Data.Text as T
import qualified Data.Text.IO as T
import Development.IDE.Core.API
import Development.IDE.Core.Compile (compileModule, typecheckModule, RunSimplifier(..))
import Development.IDE.Core.RuleTypes
import Development.IDE.Core.RuleTypes.Daml
import Development.IDE.Core.Rules.Daml (diagsToIdeResult, getDamlLfVersion, getExternalPackages, ideErrorPretty)
import Development.IDE.Core.Service
import Development.IDE.Core.Shake
import Development.IDE.GHC.Util
import Development.IDE.Types.Diagnostics
import Development.IDE.Types.Location
import Development.IDE.Types.Options
import ErrUtils
import GHC
import HscTypes (HscEnv(..), mkPrintUnqualified)
import GHC hiding (typecheckModule)
import GHC.LanguageExtensions.Type
import HscTypes (HscEnv(..), HscSource(HsSrcFile))
import Language.Haskell.GhclibParserEx.Parse
import qualified Language.LSP.Types as LSP
import Module (unitIdString)
import OccName (OccSet, occName, elemOccSet, mkOccSet, mkVarOcc)
import Outputable (parens, ppr, showSDoc, showSDocForUser)
import Module (mainUnitId, unitIdString)
import OccName
import Outputable (ppr, showSDoc)
import qualified Outputable
import RdrName (mkRdrUnqual)
import RdrName (getRdrName, mkRdrUnqual)
import SrcLoc
import qualified System.Console.Repline as Repl
import System.Exit
import System.IO.Extra
@ -156,8 +169,10 @@ toTuplePat vars = noLoc $
TuplePat noExt [noLoc (VarPat noExt $ noLoc v) | v <- vars] Boxed
toTupleExpr :: LPat GhcPs -> LHsExpr GhcPs
toTupleExpr pat = noLoc $
ExplicitTuple noExt [noLoc (Present noExt (noLoc $ HsVar noExt (noLoc v))) | v <- vars] Boxed
toTupleExpr pat = case vars of
[] -> noLoc $ HsVar noExt (noLoc $ getRdrName unitDataCon)
[var] -> noLoc $ HsVar noExt (noLoc var)
_ -> noLoc $ ExplicitTuple noExt [noLoc (Present noExt (noLoc $ HsVar noExt (noLoc v))) | v <- vars] Boxed
where vars = collectPatBinders pat
-- | Type for the statements we support.
@ -230,7 +245,7 @@ topologicalSort lfPkgs = map toPkg $ topSort $ transposeG graph
-- 'importInsert'.
--
-- This avoids redundant import lines and eases removing of module imports.
newtype Imports = Imports (Map.Map ModuleName [ImportDecl GhcPs])
newtype Imports = Imports { getImports :: Map.Map ModuleName [ImportDecl GhcPs] }
-- | Add an import declaration.
--
@ -277,7 +292,6 @@ importToList (Imports imports) = concat $ Map.elems imports
data ReplState = ReplState
{ imports :: !Imports
, printUnqualified :: PrintUnqualified
, bindings :: ![(LPat GhcPs, Type)]
, lineNumber :: !Int
}
@ -379,7 +393,7 @@ runRepl importPkgs opts replClient logger ideState = do
dflags <- liftIO $
hsc_dflags . hscEnv <$>
runAction ideState (use_ GhcSession $ lineFilePath initialLineNumber)
(_, tcr) <-
_ <-
runExceptT (typecheckImports dflags (importFromList imports) initialLineNumber)
>>= \case
Left err -> do
@ -390,7 +404,6 @@ runRepl importPkgs opts replClient logger ideState = do
{ imports = importFromList imports
, bindings = []
, lineNumber = 0
, printUnqualified = getPrintUnqualified dflags tcr
}
let replM = Repl.evalReplOpts Repl.ReplOpts
{ banner = const (pure "daml> ")
@ -413,24 +426,24 @@ runRepl importPkgs opts replClient logger ideState = do
-> ReplClient.ReplResponseType
-> ExceptT Error ReplM ()
handleStmt dflags line stmt rspType = do
ReplState {imports, bindings, lineNumber, printUnqualified} <- State.get
ReplState {imports, bindings, lineNumber} <- State.get
supportedStmt <- maybe (throwError (UnsupportedStatement line)) pure (validateStmt stmt)
let rendering = renderModule dflags printUnqualified imports lineNumber bindings supportedStmt
let rendering = renderModule imports lineNumber bindings supportedStmt
(lfMod, tmrModule -> tcMod) <- printDelayedDiagnostics $ case (rspType, rendering) of
(ReplClient.ReplText, BindingRendering t) ->
tryTypecheck lineNumber (T.pack t)
tryTypecheck dflags lineNumber t
(ReplClient.ReplText, BodyRenderings {..}) ->
withExceptT getLast
$ withExceptT Last (tryTypecheck lineNumber (T.pack unitScript))
<!> withExceptT Last (tryTypecheck lineNumber (T.pack printableScript))
<!> withExceptT Last (tryTypecheck lineNumber (T.pack arbitraryScript))
<!> withExceptT Last (tryTypecheck lineNumber (T.pack purePrintableExpr))
$ withExceptT Last (tryTypecheck dflags lineNumber unitScript)
<!> withExceptT Last (tryTypecheck dflags lineNumber printableScript)
<!> withExceptT Last (tryTypecheck dflags lineNumber arbitraryScript)
<!> withExceptT Last (tryTypecheck dflags lineNumber purePrintableExpr)
(ReplClient.ReplJson, BindingRendering _) ->
throwError (ExpectedExpression line, [])
(ReplClient.ReplJson, BodyRenderings {..}) ->
withExceptT getLast
$ withExceptT Last (tryTypecheck lineNumber (T.pack arbitraryScript))
<!> withExceptT Last (tryTypecheck lineNumber (T.pack pureArbitraryExpr))
$ withExceptT Last (tryTypecheck dflags lineNumber arbitraryScript)
<!> withExceptT Last (tryTypecheck dflags lineNumber pureArbitraryExpr)
-- Type of the statement so we can give it a type annotation
-- and avoid incurring a typeclass constraint.
stmtTy <- maybe (throwError TypeError) pure (exprTy $ tm_typechecked_source tcMod)
@ -459,9 +472,10 @@ runRepl importPkgs opts replClient logger ideState = do
liftIO $ mapM_ (printDiagnostics stdout) diags
pure (Left err)
Right r -> pure (Right r)
tryTypecheck :: (MonadIO m, MonadError (Error, [[FileDiagnostic]]) m) => Int -> T.Text -> m (LF.Module, TcModuleResult)
tryTypecheck lineNumber t = do
liftIO $ setBufferModified ideState (lineFilePath lineNumber) $ Just t
tryTypecheck :: (MonadIO m, MonadError (Error, [[FileDiagnostic]]) m) => DynFlags -> Int -> ParsedSource -> m (LF.Module, TcModuleResult)
tryTypecheck dflags lineNumber source = do
let file = lineFilePath lineNumber
-- liftIO $ setBufferModified ideState file $ Just t
-- We need to temporarily suppress diagnostics since we use type errors
-- to decide what to do. If a case succeeds we immediately print all diagnostics.
-- If it fails, we return them and only print them once everything failed.
@ -469,9 +483,33 @@ runRepl importPkgs opts replClient logger ideState = do
-- here we don't want to use the `useE` function that uses cached results
let useE' k = MaybeT . use k
let writeDiags diags = atomicModifyIORef diagsRef (\f -> (f . (diags:), ()))
r <- liftIO $ withReplLogger logger writeDiags $ runAction ideState $ runMaybeT $
(,) <$> useE' GenerateDalf (lineFilePath lineNumber)
<*> useE' TypeCheck (lineFilePath lineNumber)
let handleIdeResult :: IdeResult r -> MaybeT Action r
handleIdeResult (diags, r) = do
liftIO $ writeDiags diags
MaybeT (pure r)
r <- liftIO $ withReplLogger logger writeDiags $ runAction ideState $ runMaybeT $ do
lfVersion <- lift getDamlLfVersion
let pm = toParsedModule dflags source
IdeOptions { optDefer = defer } <- lift getIdeOptions
packageState <- hscEnv <$> useE' GhcSession file
tm <- handleIdeResult =<< liftIO (typecheckModule defer packageState [] pm)
(safeMode, cgGuts, details) <- handleIdeResult =<< liftIO (compileModule (RunSimplifier False) packageState [] tm)
let core = cgGutsToCoreModule safeMode cgGuts details
PackageMap pkgMap <- useE' GeneratePackageMap file
stablePkgs <- lift $ useNoFile_ GenerateStablePackages
case convertModule lfVersion pkgMap (Map.map LF.dalfPackageId stablePkgs) False file core details of
Left diag -> handleIdeResult ([diag], Nothing)
Right v -> do
pkgs <- lift $ getExternalPackages file
let world = LF.initWorldSelf pkgs (buildPackage (optMbPackageName opts) (optMbPackageVersion opts) lfVersion [])
let simplified = LF.simplifyModule world lfVersion v
case Serializability.inferModule world lfVersion simplified of
Left err -> handleIdeResult ([ideErrorPretty file err], Nothing)
Right dalf -> do
let (_diags, checkResult) = diagsToIdeResult file $ LF.checkModule world lfVersion dalf
case checkResult of
Nothing -> MaybeT (pure Nothing)
Just () -> pure (dalf, tm)
diags <- liftIO $ ($ []) <$> readIORef diagsRef
case r of
Nothing -> throwError (TypeError, diags)
@ -479,8 +517,11 @@ runRepl importPkgs opts replClient logger ideState = do
liftIO $ mapM_ (printDiagnostics stdout) diags
pure r
typecheckImports dflags imports line =
printDelayedDiagnostics $ tryTypecheck line $
T.pack (unlines $ moduleHeader dflags imports line)
printDelayedDiagnostics $
tryTypecheck
dflags
line
(buildModule (lineModuleName line) imports [])
handleImport
:: DynFlags
-> ImportDecl GhcPs
@ -551,10 +592,9 @@ runRepl importPkgs opts replClient logger ideState = do
addImports dflags additional = do
ReplState {imports, lineNumber} <- State.get
let newImports = foldl' (flip importInsert) imports additional
(_, tcr) <- typecheckImports dflags newImports lineNumber
_ <- typecheckImports dflags newImports lineNumber
State.modify $ \s -> s
{ imports = newImports
, printUnqualified = getPrintUnqualified dflags tcr
}
removeImports
:: DynFlags
@ -566,16 +606,11 @@ runRepl importPkgs opts replClient logger ideState = do
newImports = foldl' (flip importDelete) imports modules
unless (null unknown) $
throwError $ NotImportedModules unknown
(_, tcr) <- typecheckImports dflags newImports lineNumber
_ <- typecheckImports dflags newImports lineNumber
lift $ State.modify $ \s -> s
{ imports = newImports
, printUnqualified = getPrintUnqualified dflags tcr
}
getPrintUnqualified dflags tcr =
let gblRdrEnv = tcg_rdr_env $ fst $ tm_internals_ $ tmrModule tcr
in mkPrintUnqualified dflags gblRdrEnv
exprTy :: LHsBinds GhcTc -> Maybe Type
exprTy binds = listToMaybe
[ argTy
@ -596,23 +631,23 @@ lineModuleName i = "Line" <> show i
-- | Possible ways to render a module. We take the first one that typechecks
data ModuleRenderings
= BindingRendering String -- ^ x <- e with e :: Script a for some a
= BindingRendering ParsedSource -- ^ x <- e with e :: Script a for some a
| BodyRenderings
{ unitScript :: String
{ unitScript :: ParsedSource
-- ^ e :: Script (). Here we do not print the result.
, printableScript :: String
, printableScript :: ParsedSource
-- ^ e :: Script a with for some a that is an instance of Show. Here
-- we print the result.
, arbitraryScript :: String
, arbitraryScript :: ParsedSource
-- ^ e :: Script a for some a that may not be an instance of Show.
, purePrintableExpr :: String
, purePrintableExpr :: ParsedSource
-- ^ e :: a for some a that is an instance of Show. Here we
-- print the result. Note that we do not support
-- non-printable pure expressions since there is no
-- reason to run them.
, pureArbitraryExpr :: String
, pureArbitraryExpr :: ParsedSource
-- ^ e :: a for some a that may not be an instance of Show.
} deriving Show
}
moduleImports
:: DynFlags
@ -624,101 +659,60 @@ moduleImports dflags imports =
where
renderImport imp = showSDoc dflags (ppr imp)
moduleHeader
:: DynFlags
-> Imports
-> Int
-> [String]
moduleHeader dflags imports line =
[ "{-# OPTIONS_GHC -Wno-unused-imports -Wno-partial-type-signatures #-}"
, "{-# LANGUAGE PartialTypeSignatures #-}"
, "module " <> lineModuleName line <> " where"
] <> moduleImports dflags imports
renderModule
:: DynFlags
-> PrintUnqualified
-> Imports
:: Imports
-> Int
-> [(LPat GhcPs, Type)]
-> SupportedStatement
-> ModuleRenderings
renderModule dflags printUnqualified imports line binds stmt = case stmt of
renderModule imports line binds stmt = case stmt of
BindStatement pat expr ->
BindingRendering $ unlines $
moduleHeader dflags imports line <>
[showSDoc' . Outputable.vcat $
[ exprTy "Script _"
, exprLhs
, Outputable.nest 2 $ ppr (scriptStmt (Just pat) expr returnAp)
]
]
BindingRendering
(buildExprModule file imports binds
(noLoc $ HsWildCardTy noExt)
(scriptStmt (Just pat) expr returnAp))
BodyStatement expr ->
BodyRenderings
{ unitScript = unlines $
moduleHeader dflags imports line <>
[showSDoc' . Outputable.vcat $
[ exprTy "Script ()"
, exprLhs
, Outputable.nest 2 $ ppr (scriptStmt Nothing expr returnAp)
]
]
, printableScript = unlines $
moduleHeader dflags imports line <>
[showSDoc' . Outputable.vcat $
[ exprTy "Script Text"
, exprLhs
, Outputable.nest 2 $ ppr (scriptStmt Nothing expr returnShowAp)
]
]
, arbitraryScript = unlines $
moduleHeader dflags imports line <>
[showSDoc' . Outputable.vcat $
[ exprTy "Script _"
, exprLhs
, Outputable.nest 2 $ ppr (scriptStmt Nothing expr returnAp)
]
]
, purePrintableExpr = unlines $
moduleHeader dflags imports line <>
[showSDoc' . Outputable.vcat $
[ exprTy "Script Text"
, exprLhs
, Outputable.nest 2 $ ppr $
returnShowAp expr
]
]
, pureArbitraryExpr = unlines $
moduleHeader dflags imports line <>
[showSDoc' . Outputable.vcat $
[ exprTy "Script _"
, exprLhs
, Outputable.nest 2 $ ppr (returnAp $ noLoc $ HsPar noExt expr)
]
]
{ unitScript =
buildExprModule file imports binds
(noLoc $ HsTyVar noExt NotPromoted (noLoc $ getRdrName unitTyCon))
(scriptStmt Nothing expr returnAp)
, printableScript =
buildExprModule file imports binds
(noLoc $ HsTyVar noExt NotPromoted (noLoc $ Unqual $ mkTcOcc "Text"))
(scriptStmt Nothing expr returnShowAp)
, arbitraryScript =
buildExprModule file imports binds
(noLoc $ HsWildCardTy noExt)
(scriptStmt Nothing expr returnAp)
, purePrintableExpr =
buildExprModule file imports binds
(noLoc $ HsTyVar noExt NotPromoted (noLoc $ Unqual $ mkTcOcc "Text"))
(returnShowAp expr)
, pureArbitraryExpr =
buildExprModule file imports binds
(noLoc $ HsWildCardTy noExt)
(returnAp $ noLoc $ HsPar noExt expr)
}
LetStatement binding ->
let retExpr = case binding of
FunBinding f _ -> noLoc $ HsVar noExt f
PatBinding pat _ -> toTupleExpr pat
in BindingRendering $ unlines $
moduleHeader dflags imports line <>
[ showSDoc' $ exprTy "Script _"
, showSDoc' exprLhs
, showSDoc' $ Outputable.nest 2 $ ppr $ HsDo noExt DoExpr $ noLoc
expr = noLoc $ HsDo noExt DoExpr $ noLoc
[ noLoc $ LetStmt noExt $ toLocalBinds binding
, noLoc $ LastStmt noExt (returnAp retExpr) False noSyntaxExpr
]
]
in BindingRendering
(buildExprModule file imports binds (noLoc $ HsWildCardTy noExt) expr)
where
showSDoc' = showSDocForUser dflags printUnqualified
renderPat pat = ppr pat
renderTy ty = parens (ppr ty) <> " -> "
file = lineModuleName line
-- build a script statement using the given wrapper (either `return` or `show`)
-- to wrap the final result.
scriptStmt :: Maybe (LPat GhcPs) -> LHsExpr GhcPs -> (LHsExpr GhcPs -> LHsExpr GhcPs) -> LHsExpr GhcPs
scriptStmt mbPat expr wrapper =
let pat = fromMaybe (noLoc $ VarPat noExt $ noLoc $ mkRdrUnqual $ mkVarOcc "result") mbPat
in HsDo noExt DoExpr $ noLoc
in noLoc $ HsDo noExt DoExpr $ noLoc
[ noLoc $ BindStmt noExt pat expr noSyntaxExpr noSyntaxExpr
, noLoc $ LastStmt noExt (wrapper $ toTupleExpr pat) False noSyntaxExpr
]
@ -731,12 +725,81 @@ renderModule dflags printUnqualified imports line binds stmt = case stmt of
noLoc $ HsApp noExt showExpr $
noLoc $ HsPar noExt x
showExpr = noLoc $ HsVar noExt (noLoc $ mkRdrUnqual $ mkVarOcc "show")
exprLhs = "expr " <> Outputable.hsep (map (renderPat . fst) binds) <> " = "
exprTy :: Outputable.SDoc -> Outputable.SDoc
exprTy res =
"expr : " <>
Outputable.hcat (map (renderTy . snd) binds) <>
res
buildExprModule :: String -> Imports -> [(LPat GhcPs, Type)] -> LHsType GhcPs -> LHsExpr GhcPs -> GenLocated SrcSpan (HsModule GhcPs)
buildExprModule file imports binds ty expr = buildModule file imports
[ noLoc $
SigD noExt $
TypeSig noExt [noLoc (Unqual $ mkVarOcc "expr")] $
mkLHsSigWcType $
let resTy =
mkHsAppTy
(noLoc $ HsTyVar noExt NotPromoted (noLoc $ Unqual $ mkTcOcc "Script"))
ty
in foldr (\(_, arg) acc -> noLoc $ HsFunTy noExt (noLoc $ XHsType $ NHsCoreTy arg) acc) resTy binds
, noLoc $
ValD noExt $ FunBind
noExt
(noLoc (Unqual $ mkVarOcc "expr"))
(mkMatchGroup FromSource [mkSimpleMatch (mkPrefixFunRhs exprName) (map fst binds) expr])
idHsWrapper
[]
]
where
exprName :: Located RdrName
exprName = noLoc (Unqual $ mkVarOcc "expr")
buildModule :: String -> Imports -> [LHsDecl GhcPs] -> GenLocated SrcSpan (HsModule GhcPs)
buildModule file imports decls = L dummySrcSpan HsModule
{ hsmodName = Just (noLoc (mkModuleName file))
, hsmodExports = Nothing
, hsmodImports = Preprocessor.onImports $
noLoc (simpleImportDecl (mkModuleName "Daml.Script")) :
map noLoc (concat (Map.elems $ getImports imports))
, hsmodDecls = decls
, hsmodDeprecMessage = Nothing
, hsmodHaddockModHeader = Nothing
}
where
-- For some reason GHC requires a RealSrcSpan for typechecking so we make up one.
dummySrcSpan = RealSrcSpan $ mkRealSrcSpan (mkRealSrcLoc (mkFastString (file <> ".daml")) 0 1) (mkRealSrcLoc (mkFastString file) 100000 1)
toParsedModule :: DynFlags -> GenLocated SrcSpan (HsModule GhcPs) -> ParsedModule
toParsedModule dflags source = ParsedModule
{ pm_mod_summary = ModSummary
{ ms_mod = mkModule mainUnitId (unLoc $ fromJust $ hsmodName $ unLoc source)
, ms_hsc_src = HsSrcFile
, ms_obj_date = Nothing
, ms_iface_date = Nothing
, ms_hie_date = Nothing
, ms_srcimps = []
, ms_textual_imps = []
, ms_parsed_mod = Nothing
, ms_hspp_opts =
flip xopt_set PartialTypeSignatures $
foldl'
wopt_unset
dflags
[ Opt_WarnUnusedImports
, Opt_WarnPartialTypeSignatures
]
, ms_hspp_buf = Nothing
-- These fields are not required for our usage
-- so rather than making up values we set them to bottom
-- making sure we error out if that changes.
, ms_location = ModLocation
{ ml_hs_file = Nothing
, ml_hi_file = error "ml_hi_file is not required"
, ml_hie_file = error "ml_hie_file is not required"
, ml_obj_file = error "ml_obj_file is not required"
}
, ms_hs_date = error "hs date is not required"
, ms_hspp_file = error "hspp file is not required"
}
, pm_parsed_source = source
, pm_extra_src_files = []
, pm_annotations = (Map.empty, Map.empty)
}
instance Applicative m => Apply (Repl.HaskelineT m) where
liftF2 = liftA2

View File

@ -7,6 +7,7 @@ module DA.Daml.Preprocessor.Records
, mkImport
, recordDotPreprocessor
, onExp
, onImports
) where

View File

@ -216,6 +216,7 @@ functionalTests replClient replLogger serviceOut options ideState = describe "re
, matchOutput "^Source:.*$"
, matchOutput "^Severity:.*$"
, matchOutput "^Message:.*$"
, matchOutput "^.*Line0.daml.*$"
, matchOutput "^.*Could not find module.*$"
, matchOutput "^.*It is not a module.*$"
, input "import DA.Time"
@ -361,6 +362,13 @@ functionalTests replClient replLogger serviceOut options ideState = describe "re
, input "lookup1 1"
, matchOutput "^Some 2$"
]
, testInteraction' "out of scope type"
[ -- import a function to build a map but not the type itself
input "import DA.Map (fromList)"
, input "let m = fromList [(0,0)]"
, input "m"
, matchOutput "Map \\[\\(0,0\\)\\]"
]
]
where
testInteraction' testName steps =