Totally redid hashing scheme for let rec

This commit is contained in:
Paul Chiusano 2015-04-21 15:40:31 -04:00
parent 23b4e33fa4
commit 0cc5a3adab
5 changed files with 142 additions and 126 deletions

View File

@ -4,22 +4,22 @@
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
module Unison.ABT
(ABT(..),abs,pattern Abs',at,Focus1,focus,freshIn,freshIn',hash,into
,modify,rename,subst,tm,v',unabs,var,var',Term(..),V,pattern Var') where
module Unison.ABT where
import Control.Applicative
import Data.Aeson (ToJSON(..),FromJSON(..))
import Data.Foldable (Foldable)
import Data.Functor.Classes (Eq1(..))
import Data.List
import Data.List hiding (cycle)
import Data.Ord
import Data.Set (Set)
import Data.Traversable
import Data.Text (Text)
import Data.Vector ((!))
import Prelude hiding (abs)
import Prelude hiding (abs,cycle)
import Unison.Symbol (Symbol)
import Data.Bytes.Serial (Serial(..), Serial1(..))
import Data.Bytes.VarInt (VarInt(..))
@ -38,13 +38,17 @@ type V = Symbol
data ABT f a
= Var V
| Cycle a
| Abs V a
| Tm (f a) deriving (Functor, Foldable, Traversable)
data Term f = Term { freevars :: Set V, out :: ABT f (Term f) }
pattern Var' v <- Term _ (Var v)
pattern Cycle' vs t <- Term _ (Cycle (AbsN' vs t))
pattern Abs' v body <- Term _ (Abs v body)
pattern AbsN' vs body <- (unabs -> (vs, body))
pattern Tm' f <- Term _ (Tm f)
v' :: Text -> V
v' = Symbol.prefix
@ -62,24 +66,29 @@ tm :: Foldable f => f (Term f) -> Term f
tm t = Term (Set.unions (fmap freevars (Foldable.toList t)))
(Tm t)
-- | Wrap a term in a `Cycle` node. The free variables recorded for the
-- result are exactly those of the wrapped term.
cycle :: Term f -> Term f
cycle body = Term (freevars body) (Cycle body)
-- | Turn a single `ABT` layer into a `Term` by dispatching each constructor
-- to its corresponding builder (`var`, `cycle`, `abs`, `tm`).
into :: Foldable f => ABT f (Term f) -> Term f
into (Var x) = var x
into (Cycle t) = cycle t
into (Abs v a) = abs v a
into (Tm t) = tm t
-- | Apply `Symbol.freshen` to `v` repeatedly until the `used` predicate
-- no longer holds; `v` is returned unchanged if it is not `used` to begin with.
fresh :: (V -> Bool) -> V -> V
fresh used = until (not . used) Symbol.freshen
-- | Renames occurrences of `old` to `new` in the given term, ignoring
-- subtrees that bind `old` (inside such an `Abs` the name is shadowed, so
-- the body is left as is). All other variables are preserved unchanged.
rename :: (Foldable f, Functor f) => V -> V -> Term f -> Term f
rename old new (Term _ t) = case t of
  -- Only the variable being renamed is replaced; any other variable must be
  -- rebuilt under its own name (not under `old`).
  Var v -> if v == old then var new else var v
  Cycle body -> cycle (rename old new body)
  Abs v body -> if v == old then abs v body
                else abs v (rename old new body)
  Tm v -> tm (fmap (rename old new) v)
-- | Apply `Symbol.freshen` to `v` repeatedly until the `used` predicate
-- no longer holds; `v` is returned unchanged if it is not `used` to begin with.
fresh :: (V -> Bool) -> V -> V
fresh used = until (not . used) Symbol.freshen
-- | Produce a variable, obtained from `x` via `fresh`, that is rejected by
-- the `memberOf` test built from the free variables of `t1` and `t2`.
-- NOTE(review): `memberOf` is not visible here; presumably it tests membership
-- in either set, so the result is fresh for (i.e. NOT free in) both terms — confirm.
freshInBoth :: Term f -> Term f -> V -> V
freshInBoth t1 t2 x = fresh (memberOf (freevars t1) (freevars t2)) x
@ -96,6 +105,7 @@ subst :: (Foldable f, Functor f) => Term f -> V -> Term f -> Term f
subst t x body = case out body of
Var v | x == v -> t
Var v -> var v
Cycle body -> cycle (subst t x body)
Abs x e -> abs x' e'
where x' = freshInBoth t body x
-- rename x to something that cannot be captured
@ -125,6 +135,9 @@ focus :: Foldable f
focus [] t = Just (t, id)
focus path@(hd:tl) t = case out t of
Var _ -> Nothing
Cycle t ->
let f (t,replace) = (t, cycle . replace)
in f <$> focus path t
Abs v t ->
let f (t,replace) = (t, abs v . replace)
in f <$> focus path t
@ -133,23 +146,33 @@ focus path@(hd:tl) t = case out t of
(t,replace) <- focus tl t
pure (t, tm . hreplace . replace)
hash :: (Foldable f, Digest.Digestable1 f) => Term f -> Digest.Hash
hash t = hash' [] t
hash :: forall f . (Foldable f, Digest.Digestable1 f) => Term f -> Digest.Hash
hash t = hash' [] t where
hash' :: [Either [V] V] -> Term f -> Digest.Hash
hash' env (Term _ t) = case t of
Var v -> maybe die hashInt ind
where lookup (Left cycle) = elem v cycle
lookup (Right v') = v == v'
ind = findIndex lookup env
-- env not likely to be very big, prefer to encode in one byte if possible
hashInt i = Digest.run (serialize (VarInt i))
die = error $ "unknown var in environment: " ++ show v
Cycle (AbsN' vs t) -> hash' (Left vs : env) t
Cycle t -> hash' env t
Abs v t -> hash' (Right v : env) t
Tm t -> Digest.run (Digest.digest1 (hashCycle env) (hash' env) $ t)
hash' :: (Foldable f, Digest.Digestable1 f) => [V] -> Term f -> Digest.Hash
hash' env (Term _ t) = case t of
Var v -> maybe die hashInt (elemIndex v env)
where die = error $ "unknown var in environment: " ++ show v
-- env not likely to be very big, prefer to encode in one byte if possible
hashInt i = Digest.run (serialize (VarInt i))
Abs v body -> hash' (v:env) body
Tm body -> Digest.digest1 (canonicalPermutation env) hash $ body
-- | Collapse all outer `Abs` ctors to a single `Abs`, by renaming all inner
-- `Abs` ctors to the name of the outermost `Abs`.
conflate :: (Functor f, Foldable f) => Term f -> Term f
conflate (Term _ (Abs v1 (Term _ (Abs v2 body)))) = conflate (abs v1 (rename v2 v1 body))
conflate t = t
hashCycle :: [Either [V] V] -> [Term f] -> Digest.DigestM (Term f -> Digest.Hash)
hashCycle env@(Left cycle : envTl) ts | length cycle == length ts =
let
permute p xs = case Vector.fromList xs of xs -> map (xs !) p
hashed = map (\(i,t) -> ((i,t), hash' env t)) (zip [0..] ts)
pt = map fst (sortBy (comparing snd) hashed)
(p,ts') = unzip pt
in case map Right (permute p cycle) ++ envTl of
env -> Foldable.traverse_ (serialize . hash' env) ts'
*> pure (hash' env)
hashCycle env ts = Foldable.traverse_ (serialize . hash' env) ts *> pure (hash' env)
unabs :: Term f -> ([V], Term f)
unabs (Term _ (Abs hd body)) =
@ -159,23 +182,11 @@ unabs t = ([], t)
-- | Wrap a term in one `abs` per variable of the list; the head of the list
-- becomes the outermost binder. An empty list returns the term unchanged.
reabs :: [V] -> Term f -> Term f
reabs [] body = body
reabs (v : rest) body = abs v (reabs rest body)
canonicalPermutation :: (Foldable f, Digest.Digestable1 f) => [V] -> [Term f] -> [Term f]
canonicalPermutation env ts =
let
permute p xs = case Vector.fromList xs of xs -> map (xs !) p
conflateds = map (hash' env . conflate) ts
-- the canonical permutation, which we get by sorting by hash
p = map fst (sortBy (comparing snd) (zip [0 :: Int ..] conflateds))
in
-- apply the canonical permutation to `ts`, then ensure each term introduces
-- its vars in the same order as this permutation
map (\t -> case unabs t of (vs, body) -> reabs (permute p vs) body)
(permute p ts)
instance (Foldable f, Functor f, Eq1 f) => Eq (Term f) where
-- alpha equivalence, works by renaming any aligned Abs ctors to use a common fresh variable
t1 == t2 = go (out t1) (out t2) where
go (Var v) (Var v2) | v == v2 = True
go (Cycle t1) (Cycle t2) = t1 == t2
go (Abs v1 body1) (Abs v2 body2) =
if v1 == v2 then body1 == body2
else let v3 = freshInBoth body1 body2 v1
@ -186,6 +197,7 @@ instance (Foldable f, Functor f, Eq1 f) => Eq (Term f) where
-- | Encodes a term as a JSON array: a constructor tag ("Var", "Cycle",
-- "Abs" or "Tm") followed by that constructor's fields. The free-variable
-- set stored in the `Term` is not written out.
instance J.ToJSON1 f => ToJSON (Term f) where
  toJSON term = case out term of
    Var v -> tagged "Var" [toJSON v]
    Cycle body -> tagged "Cycle" [toJSON body]
    Abs v body -> tagged "Abs" [toJSON v, toJSON body]
    Tm v -> tagged "Tm" [J.toJSON1 v]
    where
      -- tag first, then the constructor's encoded fields
      tagged tag fields = J.array (J.text tag : fields)
@ -193,19 +205,22 @@ instance (Foldable f, J.FromJSON1 f) => FromJSON (Term f) where
parseJSON j = do
t <- J.at0 (Aeson.withText "ABT.tag" pure) j
case t of
_ | t == "Var" -> var <$> J.at 1 Aeson.parseJSON j
_ | t == "Abs" -> abs <$> J.at 1 Aeson.parseJSON j <*> J.at 2 Aeson.parseJSON j
_ | t == "Tm" -> tm <$> J.at 1 J.parseJSON1 j
_ -> fail ("unknown tag: " ++ Text.unpack t)
_ | t == "Var" -> var <$> J.at 1 Aeson.parseJSON j
_ | t == "Cycle" -> cycle <$> J.at 1 Aeson.parseJSON j
_ | t == "Abs" -> abs <$> J.at 1 Aeson.parseJSON j <*> J.at 2 Aeson.parseJSON j
_ | t == "Tm" -> tm <$> J.at 1 J.parseJSON1 j
_ -> fail ("unknown tag: " ++ Text.unpack t)
-- | Binary encoding of terms: a one-byte constructor tag followed by the
-- constructor's fields. Tags: 0 = Var, 1 = Cycle, 2 = Abs, 3 = Tm.
-- The free-variable set stored in the `Term` is not written; decoding goes
-- through the builders `var`, `cycle`, `abs` and `tm`.
instance (Foldable f, Serial1 f) => Serial (Term f) where
  serialize (Term _ e) = case e of
    Var v -> Put.putWord8 0 *> serialize v
    Cycle body -> Put.putWord8 1 *> serialize body
    Abs v body -> Put.putWord8 2 *> serialize v *> serialize body
    Tm v -> Put.putWord8 3 *> serializeWith serialize v
  -- Each tag must appear exactly once and mirror `serialize` above; the
  -- error message lists every valid tag.
  deserialize = Get.getWord8 >>= \b -> case b of
    0 -> var <$> deserialize
    1 -> cycle <$> deserialize
    2 -> abs <$> deserialize <*> deserialize
    3 -> tm <$> deserializeWith deserialize
    _ -> fail ("unknown byte tag, expected one of {0,1,2,3}, got: " ++ show b)

View File

@ -11,12 +11,10 @@
module Unison.A_Term where
import Control.Applicative
import Control.Monad
import Data.Aeson.TH
import Data.Bytes.Serial
import Data.Foldable (Foldable, traverse_)
import Data.Functor.Classes
import Data.Maybe (listToMaybe)
import Data.Vector (Vector, (!?))
import GHC.Generics
import Data.Text (Text)
@ -46,8 +44,11 @@ data F a
| Ann a T.Type
| Vector (Vector a)
| Lam a
-- Invariant: let rec blocks have an outer IntroLetRec, then an abs introduction for
-- each binding, then a LetRec for the bindings themselves
| IntroLetRec a
| LetRec [a] a
| Let [a] a
| Let a a
deriving (Eq,Foldable,Functor,Generic1)
-- | Terms are represented as ABTs over the base functor F.
@ -62,9 +63,9 @@ pattern App' f x <- (ABT.out -> ABT.Tm (App f x))
pattern Ann' x t <- (ABT.out -> ABT.Tm (Ann x t))
pattern Vector' xs <- (ABT.out -> ABT.Tm (Vector xs))
pattern Lam' v body <- (ABT.out -> ABT.Tm (Lam (ABT.Term _ (ABT.Abs v body))))
pattern Let' bs e reconstruct rec <- (unLet -> Just (bs,e,reconstruct,rec))
pattern LetNonrec' bs e <- Let' bs e _ False
pattern LetRec' bs e <- Let' bs e _ True
pattern Let1' v b e <- (ABT.out -> ABT.Tm (Let b (ABT.Abs' v e)))
pattern Let' bs e relet rec <- (unLets -> Just (bs,e,relet,rec))
pattern LetRec' bs e <- (unLetRec -> Just (bs,e))
-- some smart constructors
@ -93,35 +94,44 @@ lam v body = ABT.tm (Lam (ABT.abs v body))
-- reference any other binding in the block in its body (including itself),
-- and the output expression may also reference any binding in the block.
letRec :: [(ABT.V,Term)] -> Term -> Term
letRec bindings e =
ABT.tm (LetRec (map (intro . snd) bindings) (intro e))
letRec bindings e = ABT.tm (IntroLetRec (foldr ABT.abs z (map fst bindings)))
where
-- each e is wrapped in N abs introductions for each binding in block
intro e = foldr ABT.abs e (map fst bindings)
z = ABT.tm (LetRec (map snd bindings) e)
-- | Smart constructor for let blocks. Each binding in the block may
-- reference only previous bindings in the block, not including itself.
-- The output expression may reference any binding in the block.
let' :: [(ABT.V,Term)] -> Term -> Term
let' bindings e =
ABT.tm (Let (map intro (zip [0..] bindings)) (introAll bindings e))
let' bindings e = foldr f e bindings
where
-- each e is wrapped in introduction of all variables declared at a previous
-- bindings in the block
intro (ind, (_, e)) = introAll (take ind bindings) e
introAll bindings e = foldr ABT.abs e (map fst bindings)
f (v,b) body = ABT.tm (Let b (ABT.abs v body))
-- | Satisfies `unLet (let' bs e) == Just (bs, e, let')` and
-- `unLet (letRec bs e) == Just (bs, e, letRec)`
unLet :: Term -> Maybe ([(ABT.V, Term)], Term, [(ABT.V, Term)] -> Term -> Term, Bool)
unLet (ABT.Term _ (ABT.Tm t)) = case t of
Let bs e -> case extract bs e of (bs,e) -> Just (bs,e,let',False)
LetRec bs e -> case extract bs e of (bs,e) -> Just (bs,e,letRec,True)
-- | View a term as a let block, recursive or not. `unLetRec` is tried
-- first; if it fails, `unLet` is tried. Alongside the bindings and body, the
-- result carries the matching constructor (`letRec` or `let'`) and a flag
-- that is `True` for the recursive case. Intended to satisfy
-- `unLets (letRec bs e) == Just (bs, e, letRec, True)` and
-- `unLets (let' bs e) == Just (bs, e, let', False)`.
-- Useful for writing code agnostic to whether a let block is recursive or not.
unLets :: Term -> Maybe ([(ABT.V,Term)], Term, [(ABT.V,Term)] -> Term -> Term, Bool)
unLets t = recursive <|> nonrecursive
  where
    recursive = (\(bs, e) -> (bs, e, letRec, True)) <$> unLetRec t
    nonrecursive = (\(bs, e) -> (bs, e, let', False)) <$> unLet t
-- | Satisfies `unLetRec (letRec bs e) == Just (bs, e)`
unLetRec :: Term -> Maybe ([(ABT.V, Term)], Term)
unLetRec (ABT.Term _ (ABT.Tm t)) = case t of
IntroLetRec c -> case ABT.unabs c of
(vs, ABT.out -> ABT.Tm (LetRec bs e)) | length vs == length bs -> Just (zip vs bs, e)
_ -> Nothing
_ -> Nothing
where
extract bs e = case ABT.unabs e of
(vs, e) -> (zip vs (map (snd . ABT.unabs) bs), e)
unLet _ = Nothing
unLetRec _ = Nothing
-- | Collect the chain of nested non-recursive `Let` nodes at the root of a
-- term into a list of bindings (outermost first) plus the innermost body.
-- Returns `Nothing` when the term does not start with a `Let` over an `Abs`.
-- Intended to satisfy `unLet (let' bs e) == Just (bs, e)`.
unLet :: Term -> Maybe ([(ABT.V, Term)], Term)
unLet term = case collect term of
  ([], _) -> Nothing
  found -> Just found
  where
    -- peel one `Let b (Abs v rest)` layer at a time
    collect (ABT.out -> ABT.Tm (Let b (ABT.Abs' v rest))) =
      case collect rest of (bindings, body) -> ((v, b) : bindings, body)
    collect body = ([], body)
-- Paths into terms, represented as lists of @PathElement@
@ -137,17 +147,16 @@ type Path = [PathElement]
-- | Use a @PathElement@ to compute one step into an @F a@ subexpression
focus1 :: PathElement -> ABT.Focus1 F a
-- focus1 e (IntroLetRec c) = Just (c, )
focus1 Fn (App f x) = Just (f, \f -> App f x)
focus1 Arg (App f x) = Just (x, \x -> App f x)
focus1 Body (Lam body) = Just (body, Lam)
focus1 Body (Let bs body) = Just (body, Let bs)
focus1 Body (Let b body) = Just (body, Let b)
focus1 Body (LetRec bs body) = Just (body, LetRec bs)
focus1 (Binding i) (Let bs body) =
listToMaybe (drop i bs)
>>= \b -> Just (b, \b -> Let (take i bs ++ [b] ++ drop (i+1) bs) body)
focus1 (Binding i) (LetRec bs body) =
listToMaybe (drop i bs)
>>= \b -> Just (b, \b -> LetRec (take i bs ++ [b] ++ drop (i+1) bs) body)
focus1 (Binding i) (Let b body) | i <= 0 = Just (b, \b -> Let b body)
--focus1 (Binding i) (LetRec bs body) =
-- listToMaybe (drop i bs)
-- >>= \b -> Just (b, \b -> LetRec (take i bs ++ [b] ++ drop (i+1) bs) body)
focus1 (Index i) (Vector vs) =
vs !? i >>= \v -> Just (v, \v -> Vector (Vector.update vs (Vector.singleton (i,v))))
focus1 _ _ = Nothing
@ -172,10 +181,8 @@ bindingAt :: Path -> Term -> Maybe (ABT.V, Term)
bindingAt [] _ = Nothing
bindingAt path t = do
parentPath <- parent path
Let' bs _ _ _ <- at parentPath t
Binding i <- pure (last path) -- last is ok since we know path is nonempty
guard (i < length bs && i >= 0) -- list indexing is partial for no good reason
pure (bs !! i)
Let1' v b body <- at parentPath t
pure (v, b)
-- mostly boring serialization and hashing code below ...
@ -193,19 +200,21 @@ instance J.ToJSON1 F where toJSON1 f = Aeson.toJSON f
instance J.FromJSON1 F where parseJSON1 j = Aeson.parseJSON j
instance Digest.Digestable1 F where
digest1 s hash e = case e of
Lit l -> Digest.run $ Put.putWord8 0 *> serialize l
Blank -> Digest.run $ Put.putWord8 1
Ref r -> Digest.run $ Put.putWord8 2 *> serialize r
App a a2 -> Digest.run $ Put.putWord8 3 *> serialize (hash a) *> serialize (hash a2)
Ann a t -> Digest.run $ Put.putWord8 4 *> serialize (hash a) *> serialize t
Vector as -> Digest.run $ Put.putWord8 5 *> serialize (Vector.length as)
*> traverse_ (serialize . hash) as
Lam a -> Digest.run $ Put.putWord8 6 *> serialize (hash a)
-- note: we use `s` to canonicalize the order of `a:as` before hashing the sequence
LetRec as a -> Digest.run $ Put.putWord8 7 *> traverse_ (serialize . hash) (s (a:as))
-- here, order is significant, so leave order alone
Let as a -> Digest.run $ Put.putWord8 8 *> traverse_ (serialize . hash) as
*> serialize (hash a)
digest1 hashCycle hash e = case e of
Lit l -> Put.putWord8 0 *> serialize l
Blank -> Put.putWord8 1
Ref r -> Put.putWord8 2 *> serialize r
App a a2 -> Put.putWord8 3 *> serialize (hash a) *> serialize (hash a2)
Ann a t -> Put.putWord8 4 *> serialize (hash a) *> serialize t
Vector as -> Put.putWord8 5 *> serialize (Vector.length as)
*> traverse_ (serialize . hash) as
Lam a -> Put.putWord8 6 *> serialize (hash a)
-- note: we use `hashCycle` to ensure result is independent of let binding order
LetRec as a ->
Put.putWord8 7 *> do
hash <- hashCycle as
serialize (hash a) --
-- here, order is significant, so don't use hashCycle
Let b a -> Put.putWord8 8 *> serialize (hash b) *> serialize (hash a)
deriveJSON defaultOptions ''PathElement

View File

@ -86,7 +86,7 @@ abstractLet path t = f <$> Term.focus path t where
-}
allowRec :: Term.Path -> Term.Term -> Maybe (Term.Path, Term.Term)
allowRec path t = do
Term.LetNonrec' bs e <- Term.at path t
Term.Let' bs e _ False <- Term.at path t
t' <- Term.modify (const (Term.letRec bs e)) path t
pure (path, t')
@ -113,8 +113,6 @@ floatLetOut :: Term.Path -> Term.Term -> Maybe (Term.Path, Term.Term)
floatLetOut path t = do
parentPath <- Term.parent path >>= Term.parent
parent <- Term.at parentPath t
Term.Let' innerBindings e _ _ <- Term.parent path >>= \path -> Term.at path t
(v, body) <- Term.bindingAt path t
error "todo: floatLetOut finish me"
{- Example:
@ -141,13 +139,14 @@ floatLamOut _ _ = error "floatLamOut"
inline :: Term.Path -> Term.Term -> Maybe (Term.Path, Term.Term)
inline path t = do
(v,body) <- Term.bindingAt path t
guard (not (Set.member v (ABT.freevars body))) -- can't inline recursive functions
parentPath <- Term.parent path
parent <- Term.at parentPath t
case parent of
Term.Let' [_] e _ _ -> Just (parentPath, ABT.subst body v e)
Term.Let' bs e let' _ -> Just (parentPath, ABT.subst body v (let' (filter (\(v',_) -> v' /= v) bs) e))
_ -> Nothing
error "todo - inline"
--guard (not (Set.member v (ABT.freevars body))) -- can't inline recursive functions
--parentPath <- Term.parent path
--parent <- Term.at parentPath t
--case parent of
-- Term.Let' [_] e _ _ -> Just (parentPath, ABT.subst body v e)
-- Term.Let' bs e let' _ -> Just (parentPath, ABT.subst body v (let' (filter (\(v',_) -> v' /= v) bs) e))
-- _ -> Nothing
{- Example:
let x = 1 in {let y = 2 in y*y}
@ -161,10 +160,10 @@ inline path t = do
mergeLet :: Term.Path -> Term.Term -> Maybe (Term.Path, Term.Term)
mergeLet path t = do
parentPath <- Term.parent path
(innerBindings,e,_,_) <- Term.at path t >>= Term.unLet
(outerBindings,_,let',_) <- Term.at parentPath t >>= Term.unLet
(innerBindings,e) <- Term.at path t >>= Term.unLetRec
(outerBindings,_) <- Term.at parentPath t >>= Term.unLetRec
(,) parentPath <$> Term.modify
(const $ let' (outerBindings ++ innerBindings) e)
(const $ Term.letRec (outerBindings ++ innerBindings) e)
parentPath
t

View File

@ -66,12 +66,12 @@ forall v body = ABT.tm (Forall (ABT.abs v body))
instance Digest.Digestable1 F where
digest1 _ hash e = case e of
Lit l -> Digest.run $ Put.putWord8 0 *> serialize l
Arrow a b -> Digest.run $ Put.putWord8 1 *> serialize (hash a) *> serialize (hash b)
App a b -> Digest.run $ Put.putWord8 2 *> serialize (hash a) *> serialize (hash b)
Ann a k -> Digest.run $ Put.putWord8 3 *> serialize (hash a) *> serialize k
Constrain a u -> Digest.run $ Put.putWord8 4 *> serialize (hash a) *> serialize u
Forall a -> Digest.run $ Put.putWord8 5 *> serialize (hash a)
Lit l -> Put.putWord8 0 *> serialize l
Arrow a b -> Put.putWord8 1 *> serialize (hash a) *> serialize (hash b)
App a b -> Put.putWord8 2 *> serialize (hash a) *> serialize (hash b)
Ann a k -> Put.putWord8 3 *> serialize (hash a) *> serialize k
Constrain a u -> Put.putWord8 4 *> serialize (hash a) *> serialize u
Forall a -> Put.putWord8 5 *> serialize (hash a)
instance J.ToJSON1 F where
toJSON1 f = toJSON f

View File

@ -25,16 +25,9 @@ type Hash = B.ByteString
class Functor f => Digestable1 f where
-- | Produce a hash for an `f a`, given a hashing function for `a`.
-- The first argument, @s@ can be used by the instance to produce
-- a canonical permutation of any sequence of @a@ values, useful
-- if the instance contains @a@ values whose order should not affect
-- hash results. We can think of @s@ as a sort function using some
-- ordering that the instance doesn't have to be aware of.
--
-- More precisely, @s@ will have the property that for any
-- @xs = [x1, x2, .. xN]@, @s@ will produce the same permutation of
-- @xs@ for any permutation of @xs@ as input.
digest1 :: ([a] -> [a]) -> (a -> Hash) -> f a -> Hash
-- The first argument, `hashCycle`, can be used by instances to hash
-- `a` values whose order should not affect hash results.
digest1 :: ([a] -> DigestM (a -> Hash)) -> (a -> Hash) -> f a -> Digest
run :: Digest -> B.ByteString
run d = case digest d H.hashInit of