Rework pattern compilation significantly

- Avoid having to figure out when we need to rebind variables for
  VarP cases by pre-binding scrutinees and only casing on variables.
- Make use of constructor information for compiling data, so that
  cases are filled out based on all constructors. This fixes previously
  incorrect behavior in how variable/default cases were
  being handled.
- Built-in special patterns work similarly to the old methodology, but
  don't have the complications of data matching (aside from sequences,
  which are still TBD).
This commit is contained in:
Dan Doel 2020-06-17 11:02:31 -04:00
parent 33eb2a8d2b
commit 0f4cb488b2
3 changed files with 276 additions and 153 deletions

View File

@ -19,6 +19,7 @@ import qualified Data.Map.Strict as Map
import qualified Unison.Term as Tm
import Unison.Var (Var)
import Unison.DataDeclaration (constructorFields,asDataDecl)
import qualified Unison.LabeledDependency as RF
import Unison.Reference (Reference)
import qualified Unison.Reference as RF
@ -45,6 +46,7 @@ data EvalCtx v
, refTy :: Map.Map RF.Reference RTag
, refTm :: Map.Map RF.Reference Word64
, combs :: EnumMap Word64 Comb
, dspec :: DataSpec
, backrefTy :: EnumMap RTag RF.Reference
, backrefTm :: EnumMap Word64 (Term v)
}
@ -62,6 +64,7 @@ baseContext
, refTy = builtinTypeNumbering
, refTm = builtinTermNumbering
, combs = emitComb @v mempty <$> numberedTermLookup
, dspec = builtinDataSpec
, backrefTy = builtinTypeBackref
, backrefTm = Tm.ref () <$> builtinTermBackref
}
@ -83,13 +86,15 @@ baseContext
allocType
:: EvalCtx v
-> RF.Reference
-> [[RF.Reference]]
-> IO (EvalCtx v)
allocType _ b@(RF.Builtin _)
allocType _ b@(RF.Builtin _) _
= die $ "Unknown builtin type reference: " ++ show b
allocType ctx r
allocType ctx r cons
= pure $ ctx
{ refTy = Map.insert r frsh $ refTy ctx
, backrefTy = mapInsert frsh r $ backrefTy ctx
, dspec = Map.insert r cons $ dspec ctx
, freshTy = freshTy ctx + 1
}
where
@ -99,12 +104,17 @@ collectDeps
:: Var v
=> CodeLookup v IO ()
-> Term v
-> IO ([Reference], [Reference])
collectDeps _ tm
= pure $ foldr categorize ([],[]) chld
-> IO ([(Reference,[[Reference]])], [Reference])
collectDeps cl tm
= (,tms) <$> traverse getDecl tys
where
chld = toList $ Tm.labeledDependencies tm
categorize = either (first . (:)) (second . (:)) . RF.toReference
(tys, tms) = foldr categorize ([],[]) chld
getDecl ty@(RF.DerivedId i) =
(ty,) . maybe [] (constructorFields . asDataDecl)
<$> getTypeDeclaration cl i
getDecl r = pure (r,[])
loadDeps
:: Var v
@ -115,7 +125,8 @@ loadDeps
loadDeps cl ctx tm = do
(tys, _ ) <- collectDeps cl tm
-- TODO: terms
foldM allocType ctx $ filter (`Map.notMember`refTy ctx) tys
foldM (uncurry . allocType) ctx
$ filter (\(r,_) -> r `Map.notMember` refTy ctx) tys
-- | Merge a map of combinators into the context. The new map is the
-- left operand of '<>' and the context's existing 'combs' the right.
addCombs :: EnumMap Word64 Comb -> EvalCtx v -> EvalCtx v
addCombs new ctx = ctx { combs = new <> existing }
  where
  existing = combs ctx
@ -138,7 +149,7 @@ compileTerm w tm ctx
. emitCombs frsh
. superNormalize (ref $ refTm ctx) (ref $ refTy ctx)
. lamLift
. splitPatterns
. splitPatterns (dspec ctx)
$ tm
where
frsh = freshTm ctx

View File

@ -1,79 +1,69 @@
{-# language BangPatterns #-}
{-# language ViewPatterns #-}
{-# language PatternGuards #-}
{-# language TupleSections #-}
{-# language PatternSynonyms #-}
module Unison.Runtime.Pattern
( splitPatterns
( DataSpec
, splitPatterns
, builtinDataSpec
) where
import Control.Applicative ((<|>))
import Control.Lens ((<&>))
import Control.Monad.State (State, state, runState, modify)
import Data.Bifunctor (bimap)
import Data.List (splitAt, findIndex)
import Data.Maybe (catMaybes)
import Data.Set as Set (Set, insert, fromList)
import Data.Set as Set (Set, insert, fromList, member)
import Unison.ABT
(absChain', freshIn, visitPure, pattern AbsN', changeVars)
-- import qualified Unison.ABT as ABT
import Unison.Builtin.Decls (builtinDataDecls)
import Unison.DataDeclaration (constructorFields)
import Unison.Pattern
-- import Unison.Reference (Reference)
import Unison.Term
import Unison.Reference (Reference(..))
import Unison.Symbol (Symbol)
import Unison.Term hiding (Term)
import qualified Unison.Term as Tm
import Unison.Var (Var, typed, pattern Pattern)
import Data.Map.Strict as Map
(Map, fromListWith, lookup, toList, insertWith)
import qualified Unison.Type as Rf
-- newtype DataSpec = DS (Map Reference [Int])
import Data.Map.Strict
(Map, toList, fromListWith, insertWith)
import qualified Data.Map.Strict as Map
type PatternV a v = PatternP (a, v)
type Term v = Tm.Term v ()
type Cons = [[Reference]]
data PatternRow v a
type DataSpec = Map Reference Cons
type PatternV v = PatternP v
type Scrut v = (v, Reference)
data PatternRow v
= PR
{ _pats :: [PatternV a v]
, guard :: Maybe (Term v a)
, body :: Term v a
{ _pats :: [PatternV v]
, guard :: Maybe (Term v)
, body :: Term v
}
-- | Wrap a term in let-bindings (built with @let1' False@) that bind
-- the variables of a row's 'VarP' patterns to the corresponding
-- scrutinee terms. The term list and the row's pattern list are walked
-- in lockstep; any shape mismatch is reported with 'error'.
rebind
:: Semigroup a
=> Var v
=> [Term v a]
-> PatternRow v a
-> Term v a
-> Term v a
-- Bindings are accumulated in reverse, so they are reversed once more
-- before being handed to let1'.
rebind tms0 (PR ps0 _ _) = let1' False . reverse $ collect [] tms0 ps0
where
-- The scrutinee is already a variable: it must be the very variable the
-- pattern names, in which case no binding is emitted.
collect acc (Var' u : tms) (VarP (_, v) : ps)
| u == v = collect acc tms ps
| otherwise
= error $ "pattern rebind: mismatched variables: " ++ show (u,v)
-- Any other scrutinee term paired with a variable pattern: bind it.
collect acc (tm : tms) (VarP (_, v) : ps)
= collect ((v, tm) : acc) tms ps
collect acc [] [] = acc
-- The remaining clauses are all failure cases: unequal lengths, or a
-- pattern that is not a plain variable.
collect _ tms@(_:_) []
= error $ "pattern rebind: more terms than patterns: " ++ show tms
collect _ [] ps@(_:_)
= error $ "pattern rebind: more patterns than terms: " ++ show ps
collect _ (_:_) (p:_)
= error $ "pattern rebind: unsplit pattern" ++ show p
-- | Constructor-field table for the builtin data declarations: each
-- entry of 'builtinDataDecls' contributes its id (wrapped in
-- 'DerivedId') mapped to the 'constructorFields' of its declaration.
builtinDataSpec :: DataSpec
builtinDataSpec = Map.fromList
  [ (DerivedId rid, constructorFields decl)
  | (_, rid, decl) <- builtinDataDecls @Symbol
  ]
-- collectRowVars
-- :: Var v
-- => [PatternV a v]
-- -> Maybe (Term v a)
-- -> Term v a
-- -> Set v
-- collectRowVars ps g b
-- = (foldMap.foldMap.foldMap) singleton ps
-- <> foldMap freeVars g <> freeVars b
data PatternMatrix v
= PM { _rows :: [PatternRow v] }
data PatternMatrix v a
= PM { _rows :: [PatternRow v a] }
type Heuristic v = PatternMatrix v -> Maybe Int
type Heuristic v a = PatternMatrix v a -> Maybe Int
choose :: [Heuristic v a] -> PatternMatrix v a -> Int
choose :: [Heuristic v] -> PatternMatrix v -> Int
choose [] _ = 0
choose (h:hs) m
| Just i <- h m = i
@ -84,74 +74,136 @@ refutable (UnboundP _) = False
refutable (VarP _) = False
refutable _ = True
rowIrrefutable :: PatternRow v a -> Bool
rowIrrefutable :: PatternRow v -> Bool
rowIrrefutable (PR ps _ _) = all (not.refutable) ps
firstRow :: ([PatternV a v] -> Maybe Int) -> Heuristic v a
firstRow :: ([PatternV v] -> Maybe Int) -> Heuristic v
firstRow f (PM (r:_)) = f $ _pats r
firstRow _ _ = Nothing
heuristics :: [Heuristic v a]
heuristics :: [Heuristic v]
heuristics = [firstRow $ findIndex refutable]
extractVar :: Var v => PatternV a v -> ((a, v), PatternV a v)
extractVar p = (loc p, p)
-- | Pair a sub-pattern with the scrutinee it introduces, if any.
-- An 'UnboundP' pattern introduces none; every other pattern yields
-- the variable stored in its annotation ('loc') together with the
-- supplied reference. The pattern itself is returned unchanged.
extractVar
  :: Var v
  => Reference
  -> PatternV v
  -> (Maybe (v, Reference), PatternV v)
extractVar r p = case p of
  UnboundP{} -> (Nothing, p)
  _          -> (Just (loc p, r), p)
decomposePattern :: Var v => PatternV a v -> [((a, v), PatternV a v)]
decomposePattern (ConstructorP _ _ _ ps)
= extractVar <$> ps
decomposePattern (EffectBindP _ _ _ ps pk)
= fmap extractVar $ ps ++ [pk]
decomposePattern (EffectPureP _ p)
= [extractVar p]
decomposePattern (SequenceLiteralP _ _)
-- | 'zipWith' that insists both lists have the same length. If one
-- list runs out before the other, 'error' is called with the supplied
-- message. The tail of the result is forced before each cons cell is
-- returned, so a length mismatch surfaces as soon as the result is
-- evaluated to weak head normal form rather than lazily later on.
zipWithExact :: String -> (a -> b -> c) -> [a] -> [b] -> [c]
zipWithExact msg f = go
  where
  go [] [] = []
  go (x:xs) (y:ys) =
    let rest = go xs ys
    in rest `seq` (f x y : rest)
  go _ _ = error msg
decomposePattern
:: Var v
=> Int -> [Reference] -> PatternV v
-> [[(Maybe (v,Reference), PatternV v)]]
decomposePattern t flds p@(ConstructorP _ _ u ps)
| t == u
= [zipWithExact err extractVar flds ps]
where
err = "decomposePattern: mismatched constructor fields: "
++ show (flds, p)
decomposePattern t flds p@(EffectBindP _ _ u ps pk)
| t == u
= [zipWithExact err extractVar flds $ ps ++ [pk]]
where
err = "decomposePattern: mismatched ability fields: "
++ show (flds, p)
decomposePattern t flds (EffectPureP _ p)
| t == -1
= [zipWithExact err extractVar flds [p]]
where
err = "decomposePattern: wrong number of fields for effect-pure: "
++ show flds
decomposePattern _ flds (VarP _)
= [(Nothing, UnboundP (typed Pattern)) <$ flds]
decomposePattern _ flds (UnboundP _)
= [(Nothing, UnboundP (typed Pattern)) <$ flds]
decomposePattern _ _ (SequenceLiteralP _ _)
= error "decomposePattern: sequence literal"
decomposePattern _ = []
decomposePattern _ _ _ = []
-- | Recognise the patterns handled by the builtin matching path.
-- Variable, wildcard, and Nat/Int/Text/Char/Float literal patterns are
-- returned with their annotation replaced by @()@; every other pattern
-- yields 'Nothing'.
matchBuiltin :: PatternP a -> Maybe (PatternP ())
matchBuiltin pat = case pat of
  VarP _     -> Just (VarP ())
  UnboundP _ -> Just (UnboundP ())
  NatP _ n   -> Just (NatP () n)
  IntP _ n   -> Just (IntP () n)
  TextP _ t  -> Just (TextP () t)
  CharP _ c  -> Just (CharP () c)
  FloatP _ d -> Just (FloatP () d)
  _          -> Nothing
splitRow
:: Var v
=> Int
-> PatternRow v a
-> (PatternP a, [([(a,v)], PatternRow v a)])
splitRow i (PR (splitAt i -> (pl, sp : pr)) g b)
| VarP av@(a,_) <- sp
= (VarP a, [([av], PR (pl ++ pr) g b)])
| otherwise
= (fst <$> sp, [(vars, PR (pl ++ subs ++ pr) g b)])
where
(vars, subs) = unzip $ decomposePattern sp
splitRow _ _ = error "splitRow: bad index"
-> Int
-> [Reference]
-> PatternRow v
-> [([(v,Reference)], PatternRow v)]
splitRow i t flds (PR (splitAt i -> (pl, sp : pr)) g b)
= bimap catMaybes (\subs -> PR (pl ++ subs ++ pr) g b)
. unzip <$> decomposePattern t flds sp
splitRow _ _ _ _ = error "splitRow: bad index"
renameRow :: Var v => Map v v -> PatternRow v a -> PatternRow v a
-- | Split one row on column @i@ for a builtin-typed scrutinee.
-- When the pattern in that column is accepted by 'matchBuiltin', the
-- result is a single entry keyed by the annotation-free pattern, with
-- the column removed from the row and no new scrutinee variables.
-- Otherwise the row contributes nothing. An index past the end of the
-- row is reported with 'error'.
splitRowBuiltin
  :: Var v
  => Int
  -> PatternRow v
  -> [(PatternP (), [([(v,Reference)], PatternRow v)])]
splitRowBuiltin i (PR ps g b) = case splitAt i ps of
  (pl, sp : pr) ->
    maybe [] (\p -> [(p, [([], PR (pl ++ pr) g b)])]) (matchBuiltin sp)
  _ -> error "splitRowBuiltin: bad index"
renameRow :: Var v => Map v v -> PatternRow v -> PatternRow v
renameRow m (PR p0 g0 b0) = PR p g b
where
access k
| Just v <- Map.lookup k m = v
| otherwise = k
p = (fmap.fmap.fmap) access p0
p = (fmap.fmap) access p0
g = changeVars m <$> g0
b = changeVars m b0
buildMatrix
:: Var v
=> [([(a,v)], PatternRow v a)]
-> ([(a,v)], PatternMatrix v a)
=> [([(v,Reference)], PatternRow v)]
-> ([(v,Reference)], PatternMatrix v)
buildMatrix [] = error "buildMatrix: empty rows"
buildMatrix vrs@((avs,_):_) = (avs, PM $ fixRow <$> vrs)
buildMatrix vrs@((avrs,_):_) = (avrs, PM $ fixRow <$> vrs)
where
cvs = snd <$> avs
fixRow (fmap snd -> rvs, pr)
cvs = fst <$> avrs
fixRow (fmap fst -> rvs, pr)
= renameRow (fromListWith const . zip rvs $ cvs) pr
-- | Split a pattern matrix on column @i@ for a builtin-typed scrutinee.
-- Every row is passed through 'splitRowBuiltin'; the results are
-- grouped by the builtin pattern they were keyed on, each group is
-- turned into a sub-matrix with 'buildMatrix', and the key is tagged
-- with 'Left' to mark it as a builtin pattern rather than a
-- constructor index.
splitMatrixBuiltin
  :: Var v
  => Int
  -> PatternMatrix v
  -> [(Either (PatternP ()) Int, [(v,Reference)], PatternMatrix v)]
splitMatrixBuiltin i (PM rs)
  = fmap (\(a,(b,c)) -> (Left a,b,c))
  . toList
  . fmap buildMatrix
  -- 'fromListWith' calls its function as @f new old@, so plain '(++)'
  -- would put later rows in front of earlier ones. Row order is
  -- significant (the first matching row wins), so append the new rows
  -- after the ones already collected for the key.
  . fromListWith (flip (++))
  $ splitRowBuiltin i =<< rs
splitMatrix
:: Var v
=> Int
-> PatternMatrix v a
-> [(PatternP a, [(a, v)], PatternMatrix v a)]
splitMatrix i (PM rs)
= map (\(a, (b, c)) -> (a,b,c)) . toList . fmap buildMatrix $ mmap
-> Cons
-> PatternMatrix v
-> [(Either (PatternP ()) Int, [(v,Reference)], PatternMatrix v)]
splitMatrix i cons (PM rs)
= fmap (\(a, (b, c)) -> (a,b,c)) . (fmap.fmap) buildMatrix $ mmap
where
mmap = Map.fromListWith (++) $ splitRow i <$> rs
mmap = zipWith (\t fs -> (Right t , splitRow i t fs =<< rs)) [0..] cons
type PPM v a = State (Set v, [v], Map v v) a
@ -170,73 +222,83 @@ renameTo to from
, insertWith (error "renameTo: duplicate rename") from to rn
)
prepareAs :: Var v => PatternP a -> v -> PPM v (PatternV a v)
prepareAs (UnboundP a) u = pure $ VarP (a, u)
prepareAs :: Var v => PatternP a -> v -> PPM v (PatternV v)
prepareAs (UnboundP _) u = pure $ UnboundP u
prepareAs (AsP _ p) u = prepareAs p u <* (renameTo u =<< useVar)
prepareAs (VarP a) u = VarP (a, u) <$ (renameTo u =<< useVar)
prepareAs (ConstructorP a r i ps) u = do
ConstructorP (a,u) r i <$> traverse preparePattern ps
prepareAs (EffectPureP a p) u = do
EffectPureP (a,u) <$> preparePattern p
prepareAs (EffectBindP a r i ps k) u = do
EffectBindP (a,u) r i
prepareAs (VarP _) u = UnboundP u <$ (renameTo u =<< useVar)
prepareAs (ConstructorP _ r i ps) u = do
ConstructorP u r i <$> traverse preparePattern ps
prepareAs (EffectPureP _ p) u = do
EffectPureP u <$> preparePattern p
prepareAs (EffectBindP _ r i ps k) u = do
EffectBindP u r i
<$> traverse preparePattern ps
<*> preparePattern k
prepareAs (SequenceLiteralP a ps) u = do
SequenceLiteralP (a,u) <$> traverse preparePattern ps
prepareAs (SequenceOpP a p op q) u = do
flip (SequenceOpP (a,u)) op
prepareAs (SequenceLiteralP _ ps) u = do
SequenceLiteralP u <$> traverse preparePattern ps
prepareAs (SequenceOpP _ p op q) u = do
flip (SequenceOpP u) op
<$> preparePattern p
<*> preparePattern q
prepareAs p u = pure $ (,u) <$> p
prepareAs p u = pure $ u <$ p
preparePattern :: Var v => PatternP a -> PPM v (PatternV a v)
preparePattern (VarP a) = VarP . (a,) <$> useVar
preparePattern :: Var v => PatternP a -> PPM v (PatternV v)
preparePattern (UnboundP _) = UnboundP <$> freshVar
preparePattern (VarP _) = UnboundP <$> useVar
preparePattern (AsP _ p) = prepareAs p =<< useVar
preparePattern p = prepareAs p =<< freshVar
-- | Replace a pattern by a 'VarP' carrying the same annotation ('loc').
varp :: PatternP a -> PatternP a
varp = VarP . loc
-- | Build a one-level constructor pattern for constructor @t@ of type
-- @r@. If there are fewer variables than constructor fields, every
-- field gets an 'UnboundP'; otherwise one 'VarP' is emitted per
-- variable.
buildPattern :: Reference -> Int -> [v] -> [Reference] -> PatternP ()
buildPattern r t vs rs = ConstructorP () r t subPatterns
  where
  subPatterns =
    if length vs < length rs
      then UnboundP () <$ rs
      else VarP () <$ vs
-- | Flatten a pattern to a single level: every immediate sub-pattern
-- of a constructor, effect, or sequence pattern is replaced by a
-- variable pattern (see 'varp'). Patterns without sub-patterns are
-- returned unchanged.
chopPattern :: PatternP a -> PatternP a
chopPattern pat = case pat of
  ConstructorP a r i ps  -> ConstructorP a r i (fmap varp ps)
  EffectBindP a r i ps k -> EffectBindP a r i (fmap varp ps) (varp k)
  EffectPureP a p        -> EffectPureP a (varp p)
  SequenceLiteralP a ps  -> SequenceLiteralP a (fmap varp ps)
  SequenceOpP a p op q   -> SequenceOpP a (varp p) op (varp q)
  _                      -> pat
compile
:: (Var v, Monoid a) => [Term v a] -> PatternMatrix v a -> Term v a
compile _ (PM [])
compile :: Var v => DataSpec -> [Scrut v] -> PatternMatrix v -> Term v
compile _ _ (PM [])
= error "compile: empty matrix" -- TODO: maybe generate error term
compile tms m@(PM (r:rs))
compile spec scs m@(PM (r:rs))
| rowIrrefutable r
= rebind tms r $ case guard r of
= case guard r of
Nothing -> body r
Just g -> iff mempty g (body r) $ compile tms (PM rs)
| otherwise
= case splitAt i tms of
(tmsl, scrut : tmsr) -> match mempty scrut $ f tmsl tmsr <$> sm
_ -> error "inconsistent terms and pattern matrix"
Just g -> iff mempty g (body r) $ compile spec scs (PM rs)
| (scsl, (scrut,r) : scsr) <- splitAt i scs
, r `member` builtinCase
= match () (var () scrut)
$ buildCase spec r [] scsl scsr <$> splitMatrixBuiltin i m
| (scsl, (scrut,r) : scsr) <- splitAt i scs
= case Map.lookup r spec of
Just cons ->
match () (var () scrut)
$ buildCase spec r cons scsl scsr
<$> splitMatrix i cons m
Nothing -> error $ "unknown data reference: " ++ show r
where
i = choose heuristics m
sm = splitMatrix i m
f tmsl tmsr (p, vs, m)
= MatchCase (chopPattern p) Nothing . absChain' vs
$ compile tms m
where
tms | VarP _ <- p = tmsl ++ tmsr
| otherwise = tmsl ++ fmap (uncurry var) vs ++ tmsr
compile _ _ _ = error "inconsistent terms and pattern matrix"
mkRow :: Var v => MatchCase a (Term v a) -> PatternRow v a
mkRow (MatchCase p0 g0 (AbsN' vs b))
= case runState (preparePattern p0) (avoid, vs, mempty) of
-- | Turn one branch produced by splitting a matrix into a 'MatchCase'.
--
-- The branch is either a builtin pattern ('Left'), used as-is, or a
-- constructor index ('Right'), for which a one-level pattern is made
-- with 'buildPattern' from that constructor's field list. The body is
-- the compilation of the branch's sub-matrix, abstracted over the
-- variables the branch introduces; those variables are spliced into
-- the scrutinee list between the scrutinees to the left and to the
-- right of the column that was split.
buildCase
  :: Var v
  => DataSpec
  -> Reference
  -> Cons
  -> [Scrut v]
  -> [Scrut v]
  -> (Either (PatternP ()) Int, [(v,Reference)], PatternMatrix v)
  -> MatchCase () (Term v)
buildCase spec r cons scsl scsr (epi, vrs, m)
  = MatchCase pat Nothing (absChain' binders (compile spec scruts m))
  where
  binders = [ ((), v) | (v, _) <- vrs ]
  scruts = scsl ++ vrs ++ scsr
  pat = case epi of
    Left p  -> p
    Right t -> buildPattern r t binders (cons !! t)
mkRow :: Var v => v -> MatchCase a (Term v) -> PatternRow v
mkRow sv (MatchCase p0 g0 (AbsN' vs b))
= case runState (prepareAs p0 sv) (avoid, vs, mempty) of
(p, (_, [], rn)) -> PR [p] (changeVars rn <$> g) (changeVars rn b)
_ -> error "mkRow: not all variables used"
where
@ -246,14 +308,59 @@ mkRow (MatchCase p0 g0 (AbsN' vs b))
| otherwise -> error "mkRow: guard variables do not match body"
Nothing -> Nothing
_ -> error "mkRow: impossible"
avoid = fromList vs <> maybe mempty freeVars g <> freeVars b
mkRow _ = error "mkRow: impossible"
avoid = fromList (sv:vs) <> maybe mempty freeVars g <> freeVars b
mkRow _ _ = error "mkRow: impossible"
splitPatterns
:: (Var v, Monoid a) => Term v a -> Term v a
splitPatterns = visitPure $ \case
Match' sc0 cs0 -> Just . compile [sc] . PM $ mkRow <$> cs
-- | Set up the initial state for compiling a match.
--
-- If the scrutinee is already a variable it is used directly and no
-- binding is requested. Otherwise a variable of type 'Pattern', fresh
-- with respect to the scrutinee's free variables, is chosen and
-- returned in the first component so the caller can bind it. Each
-- case becomes one row of the matrix via 'mkRow'.
initialize
  :: Var v
  => Reference
  -> Term v
  -> [MatchCase () (Term v)]
  -> (Maybe v, Scrut v, PatternMatrix v)
initialize r sc cs = (lv, (sv, r), PM (fmap (mkRow sv) cs))
  where
  (lv, sv) = case sc of
    Var' v -> (Nothing, v)
    _ ->
      let pv = freshIn (freeVars sc) (typed Pattern)
      in (Just pv, pv)
splitPatterns :: Var v => DataSpec -> Term v -> Term v
splitPatterns spec = visitPure $ \case
Match' sc0 cs0
| Just r <- determineType cs0
, (lv, scrut, pm) <- initialize r sc cs
, body <- compile spec [scrut] pm
-> Just $ case lv of
Just v -> let1 False [(((),v), sc)] body
_ -> body
where
sc = splitPatterns sc0
cs = fmap splitPatterns <$> cs0
sc = splitPatterns spec sc0
cs = fmap (splitPatterns spec) <$> cs0
_ -> Nothing
-- | Type references whose matches 'compile' sends down the builtin
-- path ('splitMatrixBuiltin') instead of looking up constructors in
-- the 'DataSpec'.
builtinCase :: Set Reference
builtinCase = Set.fromList builtinRefs
  where
  builtinRefs =
    [ Rf.intRef, Rf.natRef, Rf.floatRef, Rf.textRef, Rf.charRef ]
-- | Work out the type being matched from the patterns of the cases.
-- The first case (in order) whose pattern reveals a type wins; 'AsP'
-- wrappers are looked through. Patterns that carry no type
-- information contribute 'Nothing', which is also the result when no
-- case reveals a type.
--
-- NOTE(review): boolean and sequence patterns map to references that
-- are not members of 'builtinCase' — confirm the caller's 'DataSpec'
-- is expected to cover them.
determineType :: [MatchCase a b] -> Maybe Reference
determineType = foldr (\mc rest -> refOf (casePattern mc) <|> rest) Nothing
  where
  casePattern (MatchCase pat _ _) = pat
  refOf pat = case pat of
    AsP _ inner          -> refOf inner
    IntP{}               -> Just Rf.intRef
    NatP{}               -> Just Rf.natRef
    FloatP{}             -> Just Rf.floatRef
    BooleanP{}           -> Just Rf.booleanRef
    TextP{}              -> Just Rf.textRef
    CharP{}              -> Just Rf.charRef
    SequenceLiteralP{}   -> Just Rf.vectorRef
    SequenceOpP{}        -> Just Rf.vectorRef
    ConstructorP _ r _ _ -> Just r
    EffectBindP _ r _ _ _ -> Just r
    _                    -> Nothing

View File

@ -14,9 +14,10 @@ import Data.Word (Word64)
import Unison.Util.EnumContainers as EC
import Unison.Term (unannotate)
import Unison.Symbol (Symbol)
import Unison.Reference (Reference(Builtin))
import Unison.Runtime.Pattern (splitPatterns)
import Unison.Runtime.Pattern
import Unison.Runtime.ANF
( superNormalize
, lamLift
@ -73,6 +74,9 @@ multRec
\ _ -> f (##Nat.+ acc n) (##Nat.sub i 1)\n\
\ ##todo (##Nat.== (f 0 1000) 5000)"
-- | Empty constructor table passed to 'splitPatterns' by the tests
-- in this module.
dataSpec :: DataSpec
dataSpec = Map.empty
testEval :: String -> Test ()
testEval s = testEval0 (env aux) main
where
@ -80,7 +84,8 @@ testEval s = testEval0 (env aux) main
= emitCombs (bit 24)
. superNormalize builtins (builtinTypeNumbering Map.!)
. lamLift
. splitPatterns
. splitPatterns dataSpec
. unannotate
$ tm s
nested :: String