diff --git a/src/GHC/SourceGen/Binds/Internal.hs b/src/GHC/SourceGen/Binds/Internal.hs index 8ed4913..9a568c6 100644 --- a/src/GHC/SourceGen/Binds/Internal.hs +++ b/src/GHC/SourceGen/Binds/Internal.hs @@ -18,6 +18,7 @@ import SrcLoc (Located) import PlaceHolder (PlaceHolder(..)) #endif +import GHC.SourceGen.Pat.Internal (parenthesize) import GHC.SourceGen.Syntax.Internal -- | A binding definition inside of a @let@ or @where@ clause. @@ -85,7 +86,8 @@ matchGroup context matches = Generated where mkMatch :: RawMatch -> Match' (Located HsExpr') - mkMatch r = noExt Match context (map builtPat $ rawMatchPats r) + mkMatch r = noExt Match context + (map builtPat $ map parenthesize $ rawMatchPats r) #if !MIN_VERSION_ghc(8,4,0) -- The GHC docs say: "A type signature for the result of the match." -- The parsing step produces 'Nothing' for this field. diff --git a/src/GHC/SourceGen/Pat.hs b/src/GHC/SourceGen/Pat.hs index f6bdb8d..b601a1e 100644 --- a/src/GHC/SourceGen/Pat.hs +++ b/src/GHC/SourceGen/Pat.hs @@ -17,13 +17,11 @@ module GHC.SourceGen.Pat , sigP ) where -import SrcLoc (unLoc) import HsTypes import HsPat hiding (LHsRecField') import GHC.SourceGen.Name.Internal -import GHC.SourceGen.Overloaded (par) -import GHC.SourceGen.Expr.Internal (litNeedsParen, overLitNeedsParen) +import GHC.SourceGen.Pat.Internal import GHC.SourceGen.Syntax.Internal import GHC.SourceGen.Type.Internal (sigWcType) @@ -48,32 +46,6 @@ conP :: RdrNameStr -> [Pat'] -> Pat' conP c xs = ConPatIn (valueRdrName c) $ PrefixCon $ map (builtPat . parenthesize) xs --- Note: GHC>=8.6 inserts parentheses automatically when pretty-printing patterns. --- When we stop supporting lower versions, we may be able to simplify this. -parenthesize :: Pat' -> Pat' -parenthesize p - | needsPar p = par p - | otherwise = p - -needsPar :: Pat' -> Bool -#if MIN_VERSION_ghc(8,6,0) -needsPar (LitPat _ l) = litNeedsParen l -needsPar (NPat _ l _ _) = overLitNeedsParen $ unLoc l -#else -needsPar (LitPat l) = litNeedsParen l -needsPar (NPat l _ _ _) = overLitNeedsParen $ unLoc l -#endif -needsPar (ConPatIn _ (PrefixCon xs)) = not $ null xs -needsPar (ConPatIn _ (InfixCon _ _)) = True -needsPar ConPatOut{} = True -#if MIN_VERSION_ghc(8,6,0) -needsPar SigPat{} = True -#else -needsPar SigPatIn{} = True -needsPar SigPatOut{} = True -#endif -needsPar _ = False - recordConP :: RdrNameStr -> [(RdrNameStr, Pat')] -> Pat' recordConP c fs = ConPatIn (valueRdrName c) diff --git a/src/GHC/SourceGen/Pat/Internal.hs b/src/GHC/SourceGen/Pat/Internal.hs new file mode 100644 index 0000000..4891f4e --- /dev/null +++ b/src/GHC/SourceGen/Pat/Internal.hs @@ -0,0 +1,40 @@ +{-# LANGUAGE CPP #-} +module GHC.SourceGen.Pat.Internal where + +import HsPat (Pat(..)) +import HsTypes (HsConDetails(..)) + +import GHC.SourceGen.Expr.Internal (litNeedsParen, overLitNeedsParen) +import GHC.SourceGen.Syntax.Internal +import SrcLoc (unLoc) + +-- Note: GHC>=8.6 inserts parentheses automatically when pretty-printing patterns. +-- When we stop supporting lower versions, we may be able to simplify this. +parenthesize :: Pat' -> Pat' +parenthesize p + | needsPar p = parPat p + | otherwise = p + + +needsPar :: Pat' -> Bool +#if MIN_VERSION_ghc(8,6,0) +needsPar (LitPat _ l) = litNeedsParen l +needsPar (NPat _ l _ _) = overLitNeedsParen $ unLoc l +#else +needsPar (LitPat l) = litNeedsParen l +needsPar (NPat l _ _ _) = overLitNeedsParen $ unLoc l +#endif +needsPar (ConPatIn _ (PrefixCon xs)) = not $ null xs +needsPar (ConPatIn _ (InfixCon _ _)) = True +needsPar ConPatOut{} = True +#if MIN_VERSION_ghc(8,6,0) +needsPar SigPat{} = True +#else +needsPar SigPatIn{} = True +needsPar SigPatOut{} = True +#endif +needsPar _ = False + +parPat :: Pat' -> Pat' +parPat = noExt ParPat . builtPat + diff --git a/tests/pprint_test.hs b/tests/pprint_test.hs index c6b54cd..68f9442 100644 --- a/tests/pprint_test.hs +++ b/tests/pprint_test.hs @@ -174,6 +174,13 @@ exprsTest dflags = testGroup "Expr" , "x + y {b = x}" :~ op (var "x") "+" (recordUpd (var "y") [("b", var "x")]) ] + , test "let" + [ "let x = 1 in x" :~ let' [valBind "x" $ int 1] (var "x") + , "let f x = 1 in f" :~ + let' [ funBind "f" $ match [var "x"] $ int 1] (var "f") + , "let f (A x) = 1 in f" :~ + let' [ funBind "f" $ match [conP "A" [var "x"]] $ int 1] (var "f") + ] ] where test = testExprs dflags @@ -222,6 +229,7 @@ declsTest dflags = testGroup "Decls" [ guard (var "x") (var "False") , guard (var "otherwise") (var "True") ] + , "f (A x) = 1" :~ funBind "f" $ match [conP "A" [var "x"]] (int 1) ] , test "tyFamInst" [ "type instance Elt String = Char"