1
1
mirror of https://github.com/anoma/juvix.git synced 2024-12-02 10:47:32 +03:00

Detect constant side conditions in matches (#3133)

* Closes #3007 
* Depends on #3101 
* Detects side conditions which are `true` (removes the condition) or
`false` (removes the branch).
This commit is contained in:
Łukasz Czajka 2024-11-01 10:50:19 +01:00 committed by GitHub
parent 68a79bc8a8
commit 95275ca5c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 148 additions and 12 deletions

View File

@ -17,6 +17,7 @@ data TransformationId
| UnrollRecursion
| ComputeTypeInfo
| ComputeCaseANF
| DetectConstantSideConditions
| DetectRedundantPatterns
| MatchToCase
| EtaExpandApps
@ -59,10 +60,10 @@ data PipelineId
type TransformationLikeId = TransformationLikeId' TransformationId PipelineId
toTypecheckTransformations :: [TransformationId]
toTypecheckTransformations = [DetectRedundantPatterns, MatchToCase]
toTypecheckTransformations = [DetectConstantSideConditions, DetectRedundantPatterns, MatchToCase]
toStoredTransformations :: [TransformationId]
toStoredTransformations = [EtaExpandApps, DetectRedundantPatterns, MatchToCase, NatToPrimInt, IntToPrimInt, ConvertBuiltinTypes, OptPhaseEval, DisambiguateNames]
toStoredTransformations = [EtaExpandApps, DetectConstantSideConditions, DetectRedundantPatterns, MatchToCase, NatToPrimInt, IntToPrimInt, ConvertBuiltinTypes, OptPhaseEval, DisambiguateNames]
combineInfoTablesTransformations :: [TransformationId]
combineInfoTablesTransformations = [CombineInfoTables, FilterUnreachable]
@ -84,6 +85,7 @@ instance TransformationId' TransformationId where
LambdaLetRecLifting -> strLifting
LetRecLifting -> strLetRecLifting
TopEtaExpand -> strTopEtaExpand
DetectConstantSideConditions -> strDetectConstantSideConditions
DetectRedundantPatterns -> strDetectRedundantPatterns
MatchToCase -> strMatchToCase
EtaExpandApps -> strEtaExpandApps

View File

@ -29,6 +29,9 @@ strLetRecLifting = "letrec-lifting"
strTopEtaExpand :: Text
strTopEtaExpand = "top-eta-expand"
strDetectConstantSideConditions :: Text
strDetectConstantSideConditions = "detect-constant-side-conditions"
strDetectRedundantPatterns :: Text
strDetectRedundantPatterns = "detect-redundant-patterns"

View File

@ -22,6 +22,7 @@ import Juvix.Compiler.Core.Transformation.CombineInfoTables (combineInfoTables)
import Juvix.Compiler.Core.Transformation.ComputeCaseANF
import Juvix.Compiler.Core.Transformation.ComputeTypeInfo
import Juvix.Compiler.Core.Transformation.ConvertBuiltinTypes
import Juvix.Compiler.Core.Transformation.DetectConstantSideConditions
import Juvix.Compiler.Core.Transformation.DetectRedundantPatterns
import Juvix.Compiler.Core.Transformation.DisambiguateNames
import Juvix.Compiler.Core.Transformation.Eta
@ -76,6 +77,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
ComputeTypeInfo -> return . computeTypeInfo
ComputeCaseANF -> return . computeCaseANF
UnrollRecursion -> unrollRecursion
DetectConstantSideConditions -> mapError (JuvixError @CoreError) . detectConstantSideConditions
DetectRedundantPatterns -> mapError (JuvixError @CoreError) . detectRedundantPatterns
MatchToCase -> mapError (JuvixError @CoreError) . matchToCase
EtaExpandApps -> return . etaExpansionApps

View File

@ -12,7 +12,6 @@ module Juvix.Compiler.Core.Transformation.ComputeCaseANF (computeCaseANF) where
-- ```
-- let z := f x in case z of { c y := y + x; d y := y }
-- ```
-- This transformation is needed for the Nockma backend.
import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Extra

View File

@ -0,0 +1,49 @@
module Juvix.Compiler.Core.Transformation.DetectConstantSideConditions
( detectConstantSideConditions,
)
where
import Juvix.Compiler.Core.Error
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info.LocationInfo
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base
detectConstantSideConditions :: forall r. (Members '[Error CoreError, Reader CoreOptions] r) => Module -> Sem r Module
detectConstantSideConditions md = mapAllNodesM (umapM go) md
where
mockFile = $(mkAbsFile "/detect-constant-side-conditions")
defaultLoc = singletonInterval (mkInitialLoc mockFile)
boolSym = lookupConstructorInfo md (BuiltinTag TagTrue) ^. constructorInductive
go :: Node -> Sem r Node
go node = case node of
NMatch m -> NMatch <$> (overM matchBranches (mapMaybeM convertMatchBranch) m)
_ -> return node
convertMatchBranch :: MatchBranch -> Sem r (Maybe MatchBranch)
convertMatchBranch br@MatchBranch {..} =
case _matchBranchRhs of
MatchBranchRhsExpression {} ->
return $ Just br
MatchBranchRhsIfs ifs ->
case ifs1 of
[] ->
case nonEmpty ifs0 of
Nothing -> return Nothing
Just ifs0' -> return $ Just $ set matchBranchRhs (MatchBranchRhsIfs ifs0') br
SideIfBranch {..} : ifs1' -> do
fCoverage <- asks (^. optCheckCoverage)
unless (not fCoverage || null ifs1') $
throw
CoreError
{ _coreErrorMsg = "Redundant side condition",
_coreErrorNode = Nothing,
_coreErrorLoc = fromMaybe defaultLoc (getInfoLocation (head' ifs1' ^. sideIfBranchInfo))
}
let ifsBody = mkIfs boolSym (map (\(SideIfBranch i c b) -> (i, c, b)) ifs0) _sideIfBranchBody
return $ Just $ set matchBranchRhs (MatchBranchRhsExpression ifsBody) br
where
ifs' = filter (not . isFalseConstr . (^. sideIfBranchCondition)) (toList ifs)
(ifs0, ifs1) = span (not . isTrueConstr . (^. sideIfBranchCondition)) ifs'

View File

@ -36,7 +36,7 @@ goDetectRedundantPatterns md node = case node of
return node
_ -> return node
where
mockFile = $(mkAbsFile "/check-redundant-patterns")
mockFile = $(mkAbsFile "/detect-redundant-patterns")
defaultLoc = singletonInterval (mkInitialLoc mockFile)
checkMatch :: Match -> Sem r ()
@ -52,7 +52,7 @@ goDetectRedundantPatterns md node = case node of
unless (check matrix row) $
throw
CoreError
{ _coreErrorMsg = ppOutput ("Redundant pattern" <> seq <> ": " <> pat),
{ _coreErrorMsg = ppOutput ("Redundant pattern" <> seq <> ": " <> pat <> "\nPerhaps you mistyped a constructor name in an earlier pattern?"),
_coreErrorNode = Nothing,
_coreErrorLoc = fromMaybe defaultLoc (getInfoLocation _matchBranchInfo)
}
@ -61,8 +61,8 @@ goDetectRedundantPatterns md node = case node of
MatchBranchRhsIfs {} -> go matrix branches
where
opts = defaultOptions {_optPrettyPatterns = True}
seq = if length _matchBranchPatterns == 1 then "" else " sequence"
pat = if length _matchBranchPatterns == 1 then doc opts (head _matchBranchPatterns) else docSequence opts (toList _matchBranchPatterns)
seq = if isSingleton (toList _matchBranchPatterns) then "" else " sequence"
pat = if isSingleton (toList _matchBranchPatterns) then doc opts (head _matchBranchPatterns) else docSequence opts (toList _matchBranchPatterns)
-- Returns True if vector is useful (not redundant) for matrix, i.e. it is
-- not covered by any row in the matrix. See Definition 6 and Section 3.1 in

View File

@ -103,8 +103,8 @@ goMatchToCase recur node = case node of
mkBuiltinApp' OpFail [mkConstant' (ConstString ("Pattern sequence not matched: " <> ppTrace pat))]
where
pat = err (replicate (length vs) ValueWildcard)
seq = if length pat == 1 then "" else "sequence "
pat' = if length pat == 1 then doc defaultOptions (head' pat) else docSequence defaultOptions pat
seq = if isSingleton pat then "" else "sequence "
pat' = if isSingleton pat then doc defaultOptions (head' pat) else docSequence defaultOptions pat
r@PatternRow {..} : matrix'
| all isPatWildcard _patternRowPatterns ->
-- The first row matches all values (Section 4, case 2)

View File

@ -51,6 +51,7 @@ module Juvix.Prelude.Base.Foundation
module GHC.Generics,
module GHC.Num,
module GHC.Real,
module GHC.Utils.Misc,
module Control.Lens,
module Language.Haskell.TH.Syntax,
module Prettyprinter,
@ -197,6 +198,7 @@ import GHC.Generics (Generic)
import GHC.Num
import GHC.Real
import GHC.Stack.Types
import GHC.Utils.Misc (isSingleton)
import Language.Haskell.TH.Syntax (Exp, Lift, Q)
import Numeric hiding (exp, log, pi)
import Path (Abs, Dir, File, Path, Rel, SomeBase (..))

View File

@ -69,5 +69,17 @@ tests =
NegTest
"Test010: Redundant pattern detection with complex patterns"
$(mkRelDir ".")
$(mkRelFile "test010.juvix")
$(mkRelFile "test010.juvix"),
NegTest
"Test011: Redundant pattern detection with side conditions"
$(mkRelDir ".")
$(mkRelFile "test011.juvix"),
NegTest
"Test012: Pattern matching coverage with side conditions"
$(mkRelDir ".")
$(mkRelFile "test012.juvix"),
NegTest
"Test013: Redundant side condition detection"
$(mkRelDir ".")
$(mkRelFile "test013.juvix")
]

View File

@ -480,5 +480,10 @@ tests =
"Test081: Non-duplication in let-folding"
$(mkRelDir ".")
$(mkRelFile "test081.juvix")
$(mkRelFile "out/test081.out")
$(mkRelFile "out/test081.out"),
posTest
"Test082: Pattern matching with side conditions"
$(mkRelDir ".")
$(mkRelFile "test082.juvix")
$(mkRelFile "out/test082.out")
]

View File

@ -9,4 +9,3 @@ f : List Nat -> List Nat -> Nat
| _ nil := 0;
main : Nat := f (1 :: nil) (2 :: nil);

View File

@ -0,0 +1,15 @@
-- Redundant pattern after a true side condition
module test011;
import Stdlib.Prelude open;
f (x : List Nat) : Nat :=
case x of
| nil := 0
| x :: _ :: nil := x
| _ :: _ :: _ :: _ if true := 0
| _ :: _ :: x :: nil := x
| _ :: nil := 1
| _ := 2;
main : Nat := f (1 :: 2 :: nil);

View File

@ -0,0 +1,11 @@
-- Non-exhaustive pattern matching with false side conditions
module test012;
import Stdlib.Prelude open;
f (x : List Nat) : Nat :=
case x of
| nil := 0
| x :: _ if false := x;
main : Nat := f (1 :: 2 :: nil);

View File

@ -0,0 +1,14 @@
-- Redundant side condition
module test013;
import Stdlib.Prelude open;
f (x : List Nat) : Nat :=
case x of
| nil := 0
| x :: _ if x > 0 := x
| if true := 0
| if false := 1
| if x == 0 := 2;
main : Nat := f (1 :: 2 :: nil);

View File

@ -0,0 +1 @@
4

View File

@ -0,0 +1,22 @@
-- Pattern matching with side conditions
module test082;
import Stdlib.Prelude open;
f (lst : List Nat) : Nat :=
case lst of
| [] := 0
| x :: xs
| if x == 0 := 1
| if true := 2;
g (lst : List Nat) : Nat :=
case lst of
| [] := 0
| _ :: _ if false := 0
| x :: xs
| if x == 0 := 1
| if false := 2
| if true := 3;
main : Nat := f [0; 1; 2] + g [1; 2];