1
1
mirror of https://github.com/anoma/juvix.git synced 2024-11-30 14:13:27 +03:00

Detect redundant patterns (#3101)

* Closes #3008
* Implements the algorithm from [Luc Maranget, Warnings for Pattern
Matching](https://www.cambridge.org/core/services/aop-cambridge-core/content/view/3165B75113781E2431E3856972940347/S0956796807006223a.pdf/warnings-for-pattern-matching.pdf)
to detect redundant patterns.
* Adds an option to the Core pretty printer to print match patterns in a
user-friendly format consistent with pattern syntax in Juvix frontend
language.
This commit is contained in:
Łukasz Czajka 2024-10-30 11:38:22 +01:00 committed by GitHub
parent 23837ed745
commit 68a79bc8a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 301 additions and 38 deletions

View File

@ -17,6 +17,7 @@ data TransformationId
| UnrollRecursion
| ComputeTypeInfo
| ComputeCaseANF
| DetectRedundantPatterns
| MatchToCase
| EtaExpandApps
| DisambiguateNames
@ -58,10 +59,10 @@ data PipelineId
type TransformationLikeId = TransformationLikeId' TransformationId PipelineId
toTypecheckTransformations :: [TransformationId]
toTypecheckTransformations = [MatchToCase]
toTypecheckTransformations = [DetectRedundantPatterns, MatchToCase]
toStoredTransformations :: [TransformationId]
toStoredTransformations = [EtaExpandApps, MatchToCase, NatToPrimInt, IntToPrimInt, ConvertBuiltinTypes, OptPhaseEval, DisambiguateNames]
toStoredTransformations = [EtaExpandApps, DetectRedundantPatterns, MatchToCase, NatToPrimInt, IntToPrimInt, ConvertBuiltinTypes, OptPhaseEval, DisambiguateNames]
combineInfoTablesTransformations :: [TransformationId]
combineInfoTablesTransformations = [CombineInfoTables, FilterUnreachable]
@ -83,6 +84,7 @@ instance TransformationId' TransformationId where
LambdaLetRecLifting -> strLifting
LetRecLifting -> strLetRecLifting
TopEtaExpand -> strTopEtaExpand
DetectRedundantPatterns -> strDetectRedundantPatterns
MatchToCase -> strMatchToCase
EtaExpandApps -> strEtaExpandApps
IdentityTrans -> strIdentity

View File

@ -29,6 +29,9 @@ strLetRecLifting = "letrec-lifting"
strTopEtaExpand :: Text
strTopEtaExpand = "top-eta-expand"
strDetectRedundantPatterns :: Text
strDetectRedundantPatterns = "detect-redundant-patterns"
strMatchToCase :: Text
strMatchToCase = "match-to-case"

View File

@ -279,6 +279,11 @@ isTypeBool = \case
NPrim (TypePrim _ (PrimBool _)) -> True
_ -> False
isUniverse :: Type -> Bool
isUniverse = \case
NUniv {} -> True
_ -> False
-- | `expandType argtys ty` expands the dynamic target of `ty` to match the
-- number of arguments with types specified by `argstys`. For example,
-- `expandType [int, string] (int -> any) = int -> string -> any`.
@ -675,9 +680,19 @@ destruct = \case
concat
[ br
^. matchBranchInfo
: concatMap getPatternInfos (br ^. matchBranchPatterns)
: getSideIfBranchInfos (br ^. matchBranchRhs)
++ concatMap getPatternInfos (br ^. matchBranchPatterns)
| br <- branches
]
getSideIfBranchInfos :: MatchBranchRhs -> [Info]
getSideIfBranchInfos = \case
MatchBranchRhsExpression _ -> []
MatchBranchRhsIfs ifs -> map getSideIfBranchInfos' (toList ifs)
where
getSideIfBranchInfos' :: SideIfBranch -> Info
getSideIfBranchInfos' SideIfBranch {..} = _sideIfBranchInfo
-- sets the infos and the binder types in the patterns
setPatternsInfos :: forall r. (Members '[Input Info, Input Node] r) => NonEmpty Pattern -> Sem r (NonEmpty Pattern)
setPatternsInfos = mapM goPattern

View File

@ -197,6 +197,7 @@ data PatternWildcard' i a = PatternWildcard
data PatternConstr' i a = PatternConstr
{ _patternConstrInfo :: i,
_patternConstrFixity :: Maybe Fixity,
_patternConstrBinder :: Binder' a,
_patternConstrTag :: !Tag,
_patternConstrArgs :: ![Pattern' i a]
@ -549,7 +550,7 @@ instance (Eq a) => Eq (MatchBranch' i a) where
(MatchBranch _ pats1 b1) == (MatchBranch _ pats2 b2) = pats1 == pats2 && b1 == b2
instance (Eq a) => Eq (PatternConstr' i a) where
(PatternConstr _ _ tag1 ps1) == (PatternConstr _ _ tag2 ps2) = tag1 == tag2 && ps1 == ps2
(PatternConstr _ _ _ tag1 ps1) == (PatternConstr _ _ _ tag2 ps2) = tag1 == tag2 && ps1 == ps2
instance (Eq a) => Eq (SideIfBranch' i a) where
(SideIfBranch _ c1 b1) == (SideIfBranch _ c2 b2) = c1 == c2 && b1 == b2

View File

@ -305,18 +305,53 @@ instance (PrettyCode a) => PrettyCode (If' i a) where
instance PrettyCode PatternWildcard where
ppCode PatternWildcard {..} = do
n <- ppName KNameLocal (_patternWildcardBinder ^. binderName)
bPretty <- asks (^. optPrettyPatterns)
let name = _patternWildcardBinder ^. binderName
if
| not bPretty -> do
n <- ppName KNameLocal name
ppWithType n (_patternWildcardBinder ^. binderType)
| isPrefixOf "_" (fromText name) || name == "?" || name == "" ->
return kwWildcard
| otherwise ->
ppName KNameLocal name
instance PrettyCode PatternConstr where
ppCode PatternConstr {..} = do
n <- ppName KNameConstructor (getInfoName _patternConstrInfo)
bPretty <- asks (^. optPrettyPatterns)
let cname = getInfoName _patternConstrInfo
n <- ppName KNameConstructor cname
bn <- ppName KNameLocal (_patternConstrBinder ^. binderName)
let mkpat :: Doc Ann -> Doc Ann
mkpat pat = if _patternConstrBinder ^. binderName == "?" || _patternConstrBinder ^. binderName == "" then pat else bn <> kwAt <> parens pat
args <- mapM (ppRightExpression appFixity) _patternConstrArgs
let name = fromText (_patternConstrBinder ^. binderName)
mkpat :: Doc Ann -> Doc Ann
mkpat pat = if name == "?" || name == "" || (bPretty && isPrefixOf "_" name) then pat else bn <> kwAt <> parens pat
args0 =
if
| bPretty ->
filter (not . isWildcardTypeBinder) _patternConstrArgs
| otherwise ->
_patternConstrArgs
args <- mapM (ppRightExpression appFixity) args0
let pat = mkpat (hsep (n : args))
if
| bPretty ->
case _patternConstrFixity of
Nothing -> do
return pat
Just fixity
| isBinary fixity ->
goBinary (cname == ",") fixity n args0
| isUnary fixity ->
goUnary fixity n args0
_ -> impossible
| otherwise ->
ppWithType pat (_patternConstrBinder ^. binderType)
where
isWildcardTypeBinder :: Pattern -> Bool
isWildcardTypeBinder = \case
PatWildcard PatternWildcard {..} ->
isUniverse (typeTarget (_patternWildcardBinder ^. binderType))
_ -> False
instance PrettyCode Pattern where
ppCode = \case
@ -683,7 +718,7 @@ instance (PrettyCode a) => PrettyCode [a] where
-- printing values
--------------------------------------------------------------------------------
goBinary :: (Member (Reader Options) r) => Bool -> Fixity -> Doc Ann -> [Value] -> Sem r (Doc Ann)
goBinary :: (HasAtomicity a, PrettyCode a, Member (Reader Options) r) => Bool -> Fixity -> Doc Ann -> [a] -> Sem r (Doc Ann)
goBinary isComma fixity name = \case
[] -> return (parens name)
[arg] -> do
@ -700,7 +735,7 @@ goBinary isComma fixity name = \case
_ ->
impossible
goUnary :: (Member (Reader Options) r) => Fixity -> Doc Ann -> [Value] -> Sem r (Doc Ann)
goUnary :: (HasAtomicity a, PrettyCode a, Member (Reader Options) r) => Fixity -> Doc Ann -> [a] -> Sem r (Doc Ann)
goUnary fixity name = \case
[] -> return (parens name)
[arg] -> do
@ -731,19 +766,22 @@ instance PrettyCode Value where
ValueFun -> return "<function>"
ValueType -> return "<type>"
ppValueSequence :: (Member (Reader Options) r) => [Value] -> Sem r (Doc Ann)
ppValueSequence vs = hsep <$> mapM (ppRightExpression appFixity) vs
docValueSequence :: [Value] -> Doc Ann
docValueSequence =
run
. runReader defaultOptions
. ppValueSequence
--------------------------------------------------------------------------------
-- helper functions
--------------------------------------------------------------------------------
ppSequence ::
(PrettyCode a, HasAtomicity a, Member (Reader Options) r) =>
[a] ->
Sem r (Doc Ann)
ppSequence vs = hsep <$> mapM (ppRightExpression appFixity) vs
docSequence :: (PrettyCode a, HasAtomicity a) => Options -> [a] -> Doc Ann
docSequence opts =
run
. runReader opts
. ppSequence
ppPostExpression ::
(PrettyCode a, HasAtomicity a, Member (Reader Options) r) =>
Fixity ->

View File

@ -5,7 +5,8 @@ import Juvix.Prelude
data Options = Options
{ _optShowIdentIds :: Bool,
_optShowDeBruijnIndices :: Bool,
_optShowArgsNum :: Bool
_optShowArgsNum :: Bool,
_optPrettyPatterns :: Bool
}
makeLenses ''Options
@ -15,7 +16,8 @@ defaultOptions =
Options
{ _optShowIdentIds = False,
_optShowDeBruijnIndices = False,
_optShowArgsNum = False
_optShowArgsNum = False,
_optPrettyPatterns = False
}
traceOptions :: Options
@ -23,7 +25,8 @@ traceOptions =
Options
{ _optShowIdentIds = True,
_optShowDeBruijnIndices = True,
_optShowArgsNum = True
_optShowArgsNum = True,
_optPrettyPatterns = False
}
fromGenericOptions :: GenericOptions -> Options

View File

@ -22,6 +22,7 @@ import Juvix.Compiler.Core.Transformation.CombineInfoTables (combineInfoTables)
import Juvix.Compiler.Core.Transformation.ComputeCaseANF
import Juvix.Compiler.Core.Transformation.ComputeTypeInfo
import Juvix.Compiler.Core.Transformation.ConvertBuiltinTypes
import Juvix.Compiler.Core.Transformation.DetectRedundantPatterns
import Juvix.Compiler.Core.Transformation.DisambiguateNames
import Juvix.Compiler.Core.Transformation.Eta
import Juvix.Compiler.Core.Transformation.FoldTypeSynonyms
@ -75,6 +76,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
ComputeTypeInfo -> return . computeTypeInfo
ComputeCaseANF -> return . computeCaseANF
UnrollRecursion -> unrollRecursion
DetectRedundantPatterns -> mapError (JuvixError @CoreError) . detectRedundantPatterns
MatchToCase -> mapError (JuvixError @CoreError) . matchToCase
EtaExpandApps -> return . etaExpansionApps
DisambiguateNames -> return . disambiguateNames

View File

@ -0,0 +1,129 @@
module Juvix.Compiler.Core.Transformation.DetectRedundantPatterns where
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Error
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info.LocationInfo
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Pretty hiding (Options)
import Juvix.Compiler.Core.Transformation.Base
type PatternRow = [Pattern]
type PatternMatrix = [PatternRow]
-- | Checks for redundant patterns in `Match` nodes. The algorithm is based on
-- the paper: Luc Maranget, "Warnings for pattern matching", JFP 17 (3):
-- 387421, 2007.
detectRedundantPatterns :: (Members '[Error CoreError, Reader CoreOptions] r) => Module -> Sem r Module
detectRedundantPatterns md = do
fCoverage <- asks (^. optCheckCoverage)
if
| fCoverage ->
mapAllNodesM (umapM (goDetectRedundantPatterns md)) md
| otherwise ->
return md
goDetectRedundantPatterns ::
forall r.
(Members '[Error CoreError, Reader CoreOptions] r) =>
Module ->
Node ->
Sem r Node
goDetectRedundantPatterns md node = case node of
NMatch m -> do
checkMatch m
return node
_ -> return node
where
mockFile = $(mkAbsFile "/check-redundant-patterns")
defaultLoc = singletonInterval (mkInitialLoc mockFile)
checkMatch :: Match -> Sem r ()
checkMatch Match {..} = case _matchBranches of
[] -> return ()
MatchBranch {..} : brs -> go [toList _matchBranchPatterns] brs
where
go :: PatternMatrix -> [MatchBranch] -> Sem r ()
go matrix = \case
[] -> return ()
MatchBranch {..} : branches -> do
let row = toList _matchBranchPatterns
unless (check matrix row) $
throw
CoreError
{ _coreErrorMsg = ppOutput ("Redundant pattern" <> seq <> ": " <> pat),
_coreErrorNode = Nothing,
_coreErrorLoc = fromMaybe defaultLoc (getInfoLocation _matchBranchInfo)
}
case _matchBranchRhs of
MatchBranchRhsExpression {} -> go (row : matrix) branches
MatchBranchRhsIfs {} -> go matrix branches
where
opts = defaultOptions {_optPrettyPatterns = True}
seq = if length _matchBranchPatterns == 1 then "" else " sequence"
pat = if length _matchBranchPatterns == 1 then doc opts (head _matchBranchPatterns) else docSequence opts (toList _matchBranchPatterns)
-- Returns True if vector is useful (not redundant) for matrix, i.e. it is
-- not covered by any row in the matrix. See Definition 6 and Section 3.1 in
-- the paper.
check :: PatternMatrix -> PatternRow -> Bool
check matrix vector = case vector of
[]
| null matrix -> True
| otherwise -> False
(p : ps) -> case p of
PatConstr PatternConstr {..} ->
check
(specialize _patternConstrTag (length _patternConstrArgs) matrix)
(_patternConstrArgs ++ ps)
PatWildcard {} ->
let col = map head' matrix
tagsSet = getPatTags col
tags = toList tagsSet
ind = lookupConstructorInfo md (head' tags) ^. constructorInductive
ctrsNum = length (lookupInductiveInfo md ind ^. inductiveConstructors)
in if
| not (null tags) && length tags == ctrsNum ->
go tags
| otherwise ->
check (computeDefault matrix) ps
where
go :: [Tag] -> Bool
go = \case
[] -> False
(tag : tags') ->
check matrix' (replicate argsNum p ++ ps) || go tags'
where
argsNum = lookupConstructorInfo md tag ^. constructorArgsNum
matrix' = specialize tag argsNum matrix
getPatTags :: [Pattern] -> HashSet Tag
getPatTags = \case
[] ->
mempty
PatConstr PatternConstr {..} : pats ->
HashSet.insert _patternConstrTag (getPatTags pats)
_ : pats ->
getPatTags pats
specialize :: Tag -> Int -> PatternMatrix -> PatternMatrix
specialize tag argsNum = mapMaybe go
where
go :: PatternRow -> Maybe PatternRow
go row = case row of
PatConstr PatternConstr {..} : row'
| _patternConstrTag == tag -> Just $ _patternConstrArgs ++ row'
| otherwise -> Nothing
w@PatWildcard {} : row' ->
Just $ replicate argsNum w ++ row'
[] -> impossible
computeDefault :: PatternMatrix -> PatternMatrix
computeDefault matrix = mapMaybe go matrix
where
go :: PatternRow -> Maybe PatternRow
go row = case row of
PatConstr {} : _ -> Nothing
PatWildcard {} : row' -> Just row'
[] -> impossible

View File

@ -42,6 +42,9 @@ goMatchToCase recur node = case node of
_ ->
recur [] node
where
mockFile = $(mkAbsFile "/match-to-case")
defaultLoc = singletonInterval (mkInitialLoc mockFile)
compileMatch :: Match -> Sem r Node
compileMatch Match {..} =
go 0 (zipExact (toList _matchValues) (toList _matchValueTypes))
@ -101,9 +104,7 @@ goMatchToCase recur node = case node of
where
pat = err (replicate (length vs) ValueWildcard)
seq = if length pat == 1 then "" else "sequence "
pat' = if length pat == 1 then doc defaultOptions (head' pat) else docValueSequence pat
mockFile = $(mkAbsFile "/match-to-case")
defaultLoc = singletonInterval (mkInitialLoc mockFile)
pat' = if length pat == 1 then doc defaultOptions (head' pat) else docSequence defaultOptions pat
r@PatternRow {..} : matrix'
| all isPatWildcard _patternRowPatterns ->
-- The first row matches all values (Section 4, case 2)
@ -185,16 +186,17 @@ goMatchToCase recur node = case node of
compileMatchingRow err bindersNum vs matrix PatternRow {..} =
case _patternRowRhs of
MatchBranchRhsExpression body ->
goMatchToCase (recur . (bcs ++)) body
goMatchToCase recur' body
MatchBranchRhsIfs ifs -> do
-- If the branch has side-conditions, then we need to continue pattern
-- matching when none of the conditions is satisfied.
body <- compile err bindersNum vs matrix
md <- ask
ifs' <- mapM goSideIfBranch (toList ifs)
let boolSym = lookupConstructorInfo md (BuiltinTag TagTrue) ^. constructorInductive
ifs' = map (\(SideIfBranch i c b) -> (i, c, b)) (toList ifs)
return $ mkIfs boolSym ifs' body
where
recur' = recur . (bcs ++)
bcs =
reverse $
foldl'
@ -204,6 +206,12 @@ goMatchToCase recur node = case node of
_patternRowBinderChangesRev
(drop _patternRowIgnoredPatternsNum (zipExact _patternRowPatterns vs))
goSideIfBranch :: SideIfBranch -> Sem r (Info, Node, Node)
goSideIfBranch SideIfBranch {..} = do
cond <- goMatchToCase recur' _sideIfBranchCondition
body <- goMatchToCase recur' _sideIfBranchBody
return (_sideIfBranchInfo, cond, body)
-- `compileDefault` computes D(M) where `M = col:matrix`, as described in
-- Section 2, Figure 1 in the paper. Then it continues compilation with the
-- new matrix.
@ -238,6 +246,8 @@ goMatchToCase recur node = case node of
compileBranch err bindersNum vs col matrix tag = do
tab <- ask
let ci = lookupConstructorInfo tab tag
-- TODO: this might not work if the constructor has additional type
-- arguments which are not at the front
paramsNum = getTypeParamsNum tab (ci ^. constructorType)
argsNum = length (typeArgs (ci ^. constructorType))
bindersNum' = bindersNum + argsNum

View File

@ -1130,6 +1130,7 @@ fromPatternArg pa = case pa ^. Internal.patternArgName of
PatConstr
PatternConstr
{ _patternConstrInfo = setInfoName (ctrName ^. nameText) mempty,
_patternConstrFixity = ctrName ^. nameFixity,
_patternConstrBinder = binder ctorTy,
_patternConstrTag = tag,
_patternConstrArgs = args
@ -1165,7 +1166,7 @@ goPatternArgs ::
[Internal.PatternArg] ->
[Type] -> -- types of the patterns
Sem r MatchBranch
goPatternArgs lvl0 body = go lvl0 []
goPatternArgs lvl0 body pats0 = go lvl0 [] pats0
where
-- `lvl` is the level of the lambda-bound variable corresponding to the current pattern
go :: Level -> [Pattern] -> [Internal.PatternArg] -> [Type] -> Sem r MatchBranch
@ -1190,7 +1191,8 @@ goPatternArgs lvl0 body = go lvl0 []
impossible
([], []) -> do
body' <- goCaseBranchRhs body
return $ MatchBranch Info.empty (nonEmpty' (reverse pats)) body'
let info = setInfoLocation (getLocSpan (nonEmpty' pats0)) Info.empty
return $ MatchBranch info (nonEmpty' (reverse pats)) body'
_ ->
impossible

View File

@ -1156,7 +1156,16 @@ branchPattern' varsNum vars = do
let info = setInfoName (ci ^. constructorName) Info.empty
ty = fromMaybe mkDynamic' mty
binder = Binder "_" (Just i) ty
return (PatConstr (PatternConstr info binder tag ps), (varsNum', vars'))
pat =
PatConstr
PatternConstr
{ _patternConstrInfo = info,
_patternConstrBinder = binder,
_patternConstrTag = tag,
_patternConstrArgs = ps,
_patternConstrFixity = ci ^. constructorFixity
}
return (pat, (varsNum', vars'))
_ -> do
let vars1 = HashMap.insert txt varsNum vars
mp <- optional (symbolAt >> parens (branchPattern (varsNum + 1) vars1))

View File

@ -57,5 +57,17 @@ tests =
NegTest
"Test007: Pattern matching coverage with side conditions"
$(mkRelDir ".")
$(mkRelFile "test007.juvix")
$(mkRelFile "test007.juvix"),
NegTest
"Test008: Redundant pattern detection"
$(mkRelDir ".")
$(mkRelFile "test008.juvix"),
NegTest
"Test009: Redundant pattern detection with side conditions"
$(mkRelDir ".")
$(mkRelFile "test009.juvix"),
NegTest
"Test010: Redundant pattern detection with complex patterns"
$(mkRelDir ".")
$(mkRelFile "test010.juvix")
]

View File

@ -5,6 +5,6 @@ import Stdlib.Prelude open;
f (x : List Nat) : Nat :=
case x of
| nil := 0
| x :: _ if true := x;
| x :: _ if x > 0 := x;
main : Nat := f (1 :: 2 :: nil);

View File

@ -0,0 +1,12 @@
-- redundant pattern
module test008;
import Stdlib.Prelude open;
f (x : List Nat) : Nat :=
case x of
| nil := 0
| x :: _ := x
| _ := 0;
main : Nat := f (1 :: 2 :: nil);

View File

@ -0,0 +1,13 @@
-- redundant pattern with side conditions
module test009;
import Stdlib.Prelude open;
f (x : List Nat) : Nat :=
case x of
| nil := 0
| x :: _ if x > 0 := x
| _ := 0
| nil := 1;
main : Nat := f (1 :: 2 :: nil);

View File

@ -0,0 +1,14 @@
-- Complex redundant pattern
module test010;
import Stdlib.Prelude open;
f (x : List Nat) : Nat :=
case x of
| nil := 0
| x :: _ :: nil := x
| _ :: _ :: _ :: _ := 0
| _ :: _ :: x :: nil := x
| _ :: nil := 1;
main : Nat := f (1 :: 2 :: nil);

View File

@ -9,11 +9,9 @@ type Ord :=
| Lim : ( -> Ord) -> Ord;
addord : Ord -> Ord -> Ord
| Zord y := y
| ZOrd y := y
| (SOrd x) y := SOrd (addord x y)
| (Lim f) y := Lim (aux-addord f y);
aux-addord : ( -> Ord) -> Ord -> -> Ord
| f y z := addord (f z) y;