From 95275ca5c1f98607be7ee9f05e6bf744157090d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Czajka?= <62751+lukaszcz@users.noreply.github.com> Date: Fri, 1 Nov 2024 10:50:19 +0100 Subject: [PATCH] Detect constant side conditions in matches (#3133) * Closes #3007 * Depends on #3101 * Detects side conditions which are `true` (removes the condition) or `false` (removes the branch). --- .../Compiler/Core/Data/TransformationId.hs | 6 ++- .../Core/Data/TransformationId/Strings.hs | 3 ++ src/Juvix/Compiler/Core/Transformation.hs | 2 + .../Core/Transformation/ComputeCaseANF.hs | 1 - .../DetectConstantSideConditions.hs | 49 +++++++++++++++++++ .../Transformation/DetectRedundantPatterns.hs | 8 +-- .../Core/Transformation/MatchToCase.hs | 4 +- src/Juvix/Prelude/Base/Foundation.hs | 2 + test/Compilation/Negative.hs | 14 +++++- test/Compilation/Positive.hs | 7 ++- tests/Compilation/negative/test001.juvix | 1 - tests/Compilation/negative/test011.juvix | 15 ++++++ tests/Compilation/negative/test012.juvix | 11 +++++ tests/Compilation/negative/test013.juvix | 14 ++++++ tests/Compilation/positive/out/test082.out | 1 + tests/Compilation/positive/test082.juvix | 22 +++++++++ 16 files changed, 148 insertions(+), 12 deletions(-) create mode 100644 src/Juvix/Compiler/Core/Transformation/DetectConstantSideConditions.hs create mode 100644 tests/Compilation/negative/test011.juvix create mode 100644 tests/Compilation/negative/test012.juvix create mode 100644 tests/Compilation/negative/test013.juvix create mode 100644 tests/Compilation/positive/out/test082.out create mode 100644 tests/Compilation/positive/test082.juvix diff --git a/src/Juvix/Compiler/Core/Data/TransformationId.hs b/src/Juvix/Compiler/Core/Data/TransformationId.hs index 7e4ec2075..a951585da 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId.hs @@ -17,6 +17,7 @@ data TransformationId | UnrollRecursion | ComputeTypeInfo | ComputeCaseANF + | DetectConstantSideConditions | DetectRedundantPatterns | MatchToCase | EtaExpandApps @@ -59,10 +60,10 @@ data PipelineId type TransformationLikeId = TransformationLikeId' TransformationId PipelineId toTypecheckTransformations :: [TransformationId] -toTypecheckTransformations = [DetectRedundantPatterns, MatchToCase] +toTypecheckTransformations = [DetectConstantSideConditions, DetectRedundantPatterns, MatchToCase] toStoredTransformations :: [TransformationId] -toStoredTransformations = [EtaExpandApps, DetectRedundantPatterns, MatchToCase, NatToPrimInt, IntToPrimInt, ConvertBuiltinTypes, OptPhaseEval, DisambiguateNames] +toStoredTransformations = [EtaExpandApps, DetectConstantSideConditions, DetectRedundantPatterns, MatchToCase, NatToPrimInt, IntToPrimInt, ConvertBuiltinTypes, OptPhaseEval, DisambiguateNames] combineInfoTablesTransformations :: [TransformationId] combineInfoTablesTransformations = [CombineInfoTables, FilterUnreachable] @@ -84,6 +85,7 @@ instance TransformationId' TransformationId where LambdaLetRecLifting -> strLifting LetRecLifting -> strLetRecLifting TopEtaExpand -> strTopEtaExpand + DetectConstantSideConditions -> strDetectConstantSideConditions DetectRedundantPatterns -> strDetectRedundantPatterns MatchToCase -> strMatchToCase EtaExpandApps -> strEtaExpandApps diff --git a/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs b/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs index 0a6de2ba4..8b7f08824 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs @@ -29,6 +29,9 @@ strLetRecLifting = "letrec-lifting" strTopEtaExpand :: Text strTopEtaExpand = "top-eta-expand" +strDetectConstantSideConditions :: Text +strDetectConstantSideConditions = "detect-constant-side-conditions" + strDetectRedundantPatterns :: Text strDetectRedundantPatterns = "detect-redundant-patterns" diff --git a/src/Juvix/Compiler/Core/Transformation.hs b/src/Juvix/Compiler/Core/Transformation.hs index 084647b8b..434ef8dba 100644 --- a/src/Juvix/Compiler/Core/Transformation.hs +++ b/src/Juvix/Compiler/Core/Transformation.hs @@ -22,6 +22,7 @@ import Juvix.Compiler.Core.Transformation.CombineInfoTables (combineInfoTables) import Juvix.Compiler.Core.Transformation.ComputeCaseANF import Juvix.Compiler.Core.Transformation.ComputeTypeInfo import Juvix.Compiler.Core.Transformation.ConvertBuiltinTypes +import Juvix.Compiler.Core.Transformation.DetectConstantSideConditions import Juvix.Compiler.Core.Transformation.DetectRedundantPatterns import Juvix.Compiler.Core.Transformation.DisambiguateNames import Juvix.Compiler.Core.Transformation.Eta @@ -76,6 +77,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts ComputeTypeInfo -> return . computeTypeInfo ComputeCaseANF -> return . computeCaseANF UnrollRecursion -> unrollRecursion + DetectConstantSideConditions -> mapError (JuvixError @CoreError) . detectConstantSideConditions DetectRedundantPatterns -> mapError (JuvixError @CoreError) . detectRedundantPatterns MatchToCase -> mapError (JuvixError @CoreError) . matchToCase EtaExpandApps -> return . etaExpansionApps diff --git a/src/Juvix/Compiler/Core/Transformation/ComputeCaseANF.hs b/src/Juvix/Compiler/Core/Transformation/ComputeCaseANF.hs index fae9cba46..d80207365 100644 --- a/src/Juvix/Compiler/Core/Transformation/ComputeCaseANF.hs +++ b/src/Juvix/Compiler/Core/Transformation/ComputeCaseANF.hs @@ -12,7 +12,6 @@ module Juvix.Compiler.Core.Transformation.ComputeCaseANF (computeCaseANF) where -- ``` -- let z := f x in case z of { c y := y + x; d y := y } -- ``` --- This transformation is needed for the Nockma backend. import Juvix.Compiler.Core.Data.BinderList qualified as BL import Juvix.Compiler.Core.Extra diff --git a/src/Juvix/Compiler/Core/Transformation/DetectConstantSideConditions.hs b/src/Juvix/Compiler/Core/Transformation/DetectConstantSideConditions.hs new file mode 100644 index 000000000..aa0819367 --- /dev/null +++ b/src/Juvix/Compiler/Core/Transformation/DetectConstantSideConditions.hs @@ -0,0 +1,49 @@ +module Juvix.Compiler.Core.Transformation.DetectConstantSideConditions + ( detectConstantSideConditions, + ) +where + +import Juvix.Compiler.Core.Error +import Juvix.Compiler.Core.Extra +import Juvix.Compiler.Core.Info.LocationInfo +import Juvix.Compiler.Core.Options +import Juvix.Compiler.Core.Transformation.Base + +detectConstantSideConditions :: forall r. (Members '[Error CoreError, Reader CoreOptions] r) => Module -> Sem r Module +detectConstantSideConditions md = mapAllNodesM (umapM go) md + where + mockFile = $(mkAbsFile "/detect-constant-side-conditions") + defaultLoc = singletonInterval (mkInitialLoc mockFile) + + boolSym = lookupConstructorInfo md (BuiltinTag TagTrue) ^. constructorInductive + + go :: Node -> Sem r Node + go node = case node of + NMatch m -> NMatch <$> (overM matchBranches (mapMaybeM convertMatchBranch) m) + _ -> return node + + convertMatchBranch :: MatchBranch -> Sem r (Maybe MatchBranch) + convertMatchBranch br@MatchBranch {..} = + case _matchBranchRhs of + MatchBranchRhsExpression {} -> + return $ Just br + MatchBranchRhsIfs ifs -> + case ifs1 of + [] -> + case nonEmpty ifs0 of + Nothing -> return Nothing + Just ifs0' -> return $ Just $ set matchBranchRhs (MatchBranchRhsIfs ifs0') br + SideIfBranch {..} : ifs1' -> do + fCoverage <- asks (^. optCheckCoverage) + unless (not fCoverage || null ifs1') $ + throw + CoreError + { _coreErrorMsg = "Redundant side condition", + _coreErrorNode = Nothing, + _coreErrorLoc = fromMaybe defaultLoc (getInfoLocation (head' ifs1' ^. sideIfBranchInfo)) + } + let ifsBody = mkIfs boolSym (map (\(SideIfBranch i c b) -> (i, c, b)) ifs0) _sideIfBranchBody + return $ Just $ set matchBranchRhs (MatchBranchRhsExpression ifsBody) br + where + ifs' = filter (not . isFalseConstr . (^. sideIfBranchCondition)) (toList ifs) + (ifs0, ifs1) = span (not . isTrueConstr . (^. sideIfBranchCondition)) ifs' diff --git a/src/Juvix/Compiler/Core/Transformation/DetectRedundantPatterns.hs b/src/Juvix/Compiler/Core/Transformation/DetectRedundantPatterns.hs index f52faf8c1..42b737ef7 100644 --- a/src/Juvix/Compiler/Core/Transformation/DetectRedundantPatterns.hs +++ b/src/Juvix/Compiler/Core/Transformation/DetectRedundantPatterns.hs @@ -36,7 +36,7 @@ goDetectRedundantPatterns md node = case node of return node _ -> return node where - mockFile = $(mkAbsFile "/check-redundant-patterns") + mockFile = $(mkAbsFile "/detect-redundant-patterns") defaultLoc = singletonInterval (mkInitialLoc mockFile) checkMatch :: Match -> Sem r () @@ -52,7 +52,7 @@ goDetectRedundantPatterns md node = case node of unless (check matrix row) $ throw CoreError - { _coreErrorMsg = ppOutput ("Redundant pattern" <> seq <> ": " <> pat), + { _coreErrorMsg = ppOutput ("Redundant pattern" <> seq <> ": " <> pat <> "\nPerhaps you mistyped a constructor name in an earlier pattern?"), _coreErrorNode = Nothing, _coreErrorLoc = fromMaybe defaultLoc (getInfoLocation _matchBranchInfo) } @@ -61,8 +61,8 @@ goDetectRedundantPatterns md node = case node of MatchBranchRhsIfs {} -> go matrix branches where opts = defaultOptions {_optPrettyPatterns = True} - seq = if length _matchBranchPatterns == 1 then "" else " sequence" - pat = if length _matchBranchPatterns == 1 then doc opts (head _matchBranchPatterns) else docSequence opts (toList _matchBranchPatterns) + seq = if isSingleton (toList _matchBranchPatterns) then "" else " sequence" + pat = if isSingleton (toList _matchBranchPatterns) then doc opts (head _matchBranchPatterns) else docSequence opts (toList _matchBranchPatterns) -- Returns True if vector is useful (not redundant) for matrix, i.e. it is -- not covered by any row in the matrix. See Definition 6 and Section 3.1 in diff --git a/src/Juvix/Compiler/Core/Transformation/MatchToCase.hs b/src/Juvix/Compiler/Core/Transformation/MatchToCase.hs index 926a009b3..e8166534b 100644 --- a/src/Juvix/Compiler/Core/Transformation/MatchToCase.hs +++ b/src/Juvix/Compiler/Core/Transformation/MatchToCase.hs @@ -103,8 +103,8 @@ goMatchToCase recur node = case node of mkBuiltinApp' OpFail [mkConstant' (ConstString ("Pattern sequence not matched: " <> ppTrace pat))] where pat = err (replicate (length vs) ValueWildcard) - seq = if length pat == 1 then "" else "sequence " - pat' = if length pat == 1 then doc defaultOptions (head' pat) else docSequence defaultOptions pat + seq = if isSingleton pat then "" else "sequence " + pat' = if isSingleton pat then doc defaultOptions (head' pat) else docSequence defaultOptions pat r@PatternRow {..} : matrix' | all isPatWildcard _patternRowPatterns -> -- The first row matches all values (Section 4, case 2) diff --git a/src/Juvix/Prelude/Base/Foundation.hs b/src/Juvix/Prelude/Base/Foundation.hs index 630dd4184..2698916a2 100644 --- a/src/Juvix/Prelude/Base/Foundation.hs +++ b/src/Juvix/Prelude/Base/Foundation.hs @@ -51,6 +51,7 @@ module Juvix.Prelude.Base.Foundation module GHC.Generics, module GHC.Num, module GHC.Real, + module GHC.Utils.Misc, module Control.Lens, module Language.Haskell.TH.Syntax, module Prettyprinter, @@ -197,6 +198,7 @@ import GHC.Generics (Generic) import GHC.Num import GHC.Real import GHC.Stack.Types +import GHC.Utils.Misc (isSingleton) import Language.Haskell.TH.Syntax (Exp, Lift, Q) import Numeric hiding (exp, log, pi) import Path (Abs, Dir, File, Path, Rel, SomeBase (..)) diff --git a/test/Compilation/Negative.hs b/test/Compilation/Negative.hs index 3a209b4c8..f0def9172 100644 --- a/test/Compilation/Negative.hs +++ b/test/Compilation/Negative.hs @@ -69,5 +69,17 @@ tests = NegTest "Test010: Redundant pattern detection with complex patterns" $(mkRelDir ".") - $(mkRelFile "test010.juvix") + $(mkRelFile "test010.juvix"), + NegTest + "Test011: Redundant pattern detection with side conditions" + $(mkRelDir ".") + $(mkRelFile "test011.juvix"), + NegTest + "Test012: Pattern matching coverage with side conditions" + $(mkRelDir ".") + $(mkRelFile "test012.juvix"), + NegTest + "Test013: Redundant side condition detection" + $(mkRelDir ".") + $(mkRelFile "test013.juvix") ] diff --git a/test/Compilation/Positive.hs b/test/Compilation/Positive.hs index a7edbb72c..0885b6f9a 100644 --- a/test/Compilation/Positive.hs +++ b/test/Compilation/Positive.hs @@ -480,5 +480,10 @@ tests = "Test081: Non-duplication in let-folding" $(mkRelDir ".") $(mkRelFile "test081.juvix") - $(mkRelFile "out/test081.out") + $(mkRelFile "out/test081.out"), + posTest + "Test082: Pattern matching with side conditions" + $(mkRelDir ".") + $(mkRelFile "test082.juvix") + $(mkRelFile "out/test082.out") ] diff --git a/tests/Compilation/negative/test001.juvix b/tests/Compilation/negative/test001.juvix index 2c32aaebc..0c6a57ee7 100644 --- a/tests/Compilation/negative/test001.juvix +++ b/tests/Compilation/negative/test001.juvix @@ -9,4 +9,3 @@ f : List Nat -> List Nat -> Nat | _ nil := 0; main : Nat := f (1 :: nil) (2 :: nil); - diff --git a/tests/Compilation/negative/test011.juvix b/tests/Compilation/negative/test011.juvix new file mode 100644 index 000000000..894cfdffa --- /dev/null +++ b/tests/Compilation/negative/test011.juvix @@ -0,0 +1,15 @@ +-- Redundant pattern after a true side condition +module test011; + +import Stdlib.Prelude open; + +f (x : List Nat) : Nat := + case x of + | nil := 0 + | x :: _ :: nil := x + | _ :: _ :: _ :: _ if true := 0 + | _ :: _ :: x :: nil := x + | _ :: nil := 1 + | _ := 2; + +main : Nat := f (1 :: 2 :: nil); diff --git a/tests/Compilation/negative/test012.juvix b/tests/Compilation/negative/test012.juvix new file mode 100644 index 000000000..d02d5fa1c --- /dev/null +++ b/tests/Compilation/negative/test012.juvix @@ -0,0 +1,11 @@ +-- Non-exhaustive pattern matching with false side conditions +module test012; + +import Stdlib.Prelude open; + +f (x : List Nat) : Nat := + case x of + | nil := 0 + | x :: _ if false := x; + +main : Nat := f (1 :: 2 :: nil); diff --git a/tests/Compilation/negative/test013.juvix b/tests/Compilation/negative/test013.juvix new file mode 100644 index 000000000..c17aa03e5 --- /dev/null +++ b/tests/Compilation/negative/test013.juvix @@ -0,0 +1,14 @@ +-- Redundant side condition +module test013; + +import Stdlib.Prelude open; + +f (x : List Nat) : Nat := + case x of + | nil := 0 + | x :: _ if x > 0 := x + | if true := 0 + | if false := 1 + | if x == 0 := 2; + +main : Nat := f (1 :: 2 :: nil); diff --git a/tests/Compilation/positive/out/test082.out b/tests/Compilation/positive/out/test082.out new file mode 100644 index 000000000..b8626c4cf --- /dev/null +++ b/tests/Compilation/positive/out/test082.out @@ -0,0 +1 @@ +4 diff --git a/tests/Compilation/positive/test082.juvix b/tests/Compilation/positive/test082.juvix new file mode 100644 index 000000000..7bbfb8c2a --- /dev/null +++ b/tests/Compilation/positive/test082.juvix @@ -0,0 +1,22 @@ +-- Pattern matching with side conditions +module test082; + +import Stdlib.Prelude open; + +f (lst : List Nat) : Nat := + case lst of + | [] := 0 + | x :: xs + | if x == 0 := 1 + | if true := 2; + +g (lst : List Nat) : Nat := + case lst of + | [] := 0 + | _ :: _ if false := 0 + | x :: xs + | if x == 0 := 1 + | if false := 2 + | if true := 3; + +main : Nat := f [0; 1; 2] + g [1; 2];