1
1
mirror of https://github.com/github/semantic.git synced 2025-01-08 08:30:27 +03:00

Replace cofree with Term and TermF

This commit is contained in:
joshvera 2016-09-16 11:22:52 -04:00
parent 86a38579cd
commit f026a2d2f1
2 changed files with 62 additions and 49 deletions

View File

@ -25,7 +25,7 @@ import Data.Record
import qualified Data.Vector as Vector
import Patch
import Prologue as P
import Term (termSize, zipTerms, Term)
import Term (termSize, zipTerms, Term, TermF)
import Test.QuickCheck hiding (Fixed)
import Test.QuickCheck.Random
import qualified SES
@ -43,15 +43,12 @@ rws :: forall f fields.
(GAlign f,
Traversable f,
HasField fields Category,
Eq (f (Cofree f Category)),
HasField fields (Vector.Vector Double)) =>
(Term f (Record fields)
-> Term f (Record fields)
-> Maybe (Diff f (Record fields))) -- ^ A function which compares a pair of terms recursively, returning 'Just' their diffed value if appropriate, or 'Nothing' if they should not be compared.
-- | The list of old terms.
-> [Cofree f (Record fields)] -- ^ The list of new terms.
-> [Cofree f (Record fields)] -- ^ The resulting list of similarity-matched diffs.
-> [Free (CofreeF f (Both (Record fields))) (Patch (Cofree f (Record fields)))]
Eq (f (Term f Category)),
HasField fields (Vector.Vector Double))
=> (Term f (Record fields) -> Term f (Record fields) -> Maybe (Diff f (Record fields))) -- ^ A function which compares a pair of terms recursively, returning 'Just' their diffed value if appropriate, or 'Nothing' if they should not be compared.
-> [Term f (Record fields)] -- ^ The list of old terms.
-> [Term f (Record fields)] -- ^ The list of new terms.
-> [Diff f (Record fields)] -- ^ The resulting list of similarity-matched diffs.
rws compare as bs
| null as, null bs = []
| null as = inserting <$> bs
@ -62,28 +59,17 @@ rws compare as bs
traverse findNearestNeighbourToDiff allDiffs &
fmap catMaybes &
-- Run the state with an initial state
(`runState` (pred $ maybe 0 getMin (getOption (foldMap (Option . Just . Min . termIndex) fas)),
toMap fas,
toMap fbs)) &
(`runState` (pred $ maybe 0 getMin (getOption (foldMap (Option . Just . Min . termIndex) featurizedAs)),
toMap featurizedAs,
toMap featurizedBs)) &
uncurry deleteRemaining &
insertMapped countersAndDiffs &
fmap snd
where
sesDiffs = eitherCutoff 1 <$> SES.ses replaceIfEqual cost as bs
replaceIfEqual :: HasField fields Category => Cofree f (Record fields) -> Cofree f (Record fields) -> Maybe (Free (CofreeF f (Both (Record fields))) (Patch (Cofree f (Record fields))))
replaceIfEqual a b
| (category <$> a) == (category <$> b) = hylo wrap runCofree <$> zipTerms a b
| otherwise = Nothing
cost = iter (const 0) . (1 <$)
eitherCutoff :: (Functor f) => Integer
-> Free (CofreeF f (Both (Record fields))) (Patch (Cofree f (Record fields)))
-> Free (CofreeF f (Both (Record fields))) (Either (Free (CofreeF f (Both (Record fields))) (Patch (Cofree f (Record fields)))) (Patch (Cofree f (Record fields))))
eitherCutoff n diff | n <= 0 = pure (Left diff)
eitherCutoff n diff = free . bimap Right (eitherCutoff (pred n)) $ runFree diff
(fas, fbs, _, _, countersAndDiffs, allDiffs) =
(featurizedAs, featurizedBs, _, _, countersAndDiffs, allDiffs) =
foldl' (\(as, bs, counterA, counterB, diffs, allDiffs) diff -> case runFree diff of
Pure (Right (Delete term)) ->
(as <> pure (featurize counterA term), bs, succ counterA, counterB, diffs, allDiffs <> pure Nil)
@ -104,13 +90,6 @@ rws compare as bs
put (i, unA, unB)
pure Nothing
kdas = KdTree.build (Vector.toList . feature) fas
kdbs = KdTree.build (Vector.toList . feature) fbs
featurize index term = UnmappedTerm index (getField (extract term)) term
toMap = IntMap.fromList . fmap (termIndex &&& identity)
-- | Construct a diff for a term in B by matching it against the most similar eligible term in A (if any),
-- marking both as ineligible for future matches.
findNearestNeighbourTo :: UnmappedTerm f fields
@ -156,8 +135,6 @@ rws compare as bs
-> Maybe (UnmappedTerm f fields) -- ^ The most similar unmapped term, if any.
nearestUnmapped unmapped tree key = getFirst $ foldMap (First . Just) (sortOn (maybe maxBound (editDistanceUpTo defaultM) . compare (term key) . term) (toList (IntMap.intersection unmapped (toMap (KdTree.kNearest tree defaultL key)))))
-- | Determines whether an index is in-bounds for a move given the most recently matched index.
isInMoveBounds previous i = previous < i && i < previous + defaultMoveBound
insertMapped diffs into = foldl' (\into (i, mappedTerm) ->
insertDiff (i, mappedTerm) into)
into
@ -169,7 +146,36 @@ rws compare as bs
diffs
((termIndex &&& deleting . term) <$> unmappedA)
insertDiff :: (These Int Int, a) -> [(These Int Int, a)] -> [(These Int Int, a)]
replaceIfEqual :: HasField fields Category
=> Term f (Record fields)
-> Term f (Record fields)
-> Maybe (Diff f (Record fields))
replaceIfEqual a b
| (category <$> a) == (category <$> b) = hylo wrap runCofree <$> zipTerms a b
| otherwise = Nothing
cost = iter (const 0) . (1 <$)
eitherCutoff :: (Functor f) => Integer
-> Diff f (Record fields)
-> Free (TermF f (Both (Record fields)))
(Either (Diff f (Record fields)) (Patch (Term f (Record fields))))
eitherCutoff n diff | n <= 0 = pure (Left diff)
eitherCutoff n diff = free . bimap Right (eitherCutoff (pred n)) $ runFree diff
kdas = KdTree.build (Vector.toList . feature) featurizedAs
kdbs = KdTree.build (Vector.toList . feature) featurizedBs
featurize index term = UnmappedTerm index (getField (extract term)) term
toMap = IntMap.fromList . fmap (termIndex &&& identity)
-- | Determines whether an index is in-bounds for a move given the most recently matched index.
isInMoveBounds previous i = previous < i && i < previous + defaultMoveBound
-- | Inserts an index and diff pair into a list of indices and diffs.
insertDiff :: (These Int Int, diff) -> [(These Int Int, diff)] -> [(These Int Int, diff)]
insertDiff inserted [] = [ inserted ]
insertDiff a@(ij1, _) (b@(ij2, _):rest) = case (ij1, ij2) of
(These i1 i2, These j1 j2) -> if i1 <= j1 && i2 <= j2 then a : b : rest else b : insertDiff a rest
@ -195,7 +201,7 @@ insertDiff a@(ij1, _) (b@(ij2, _):rest) = case (ij1, ij2) of
-- | Return an edit distance as the sum of its term sizes, given a cutoff and a syntax of terms 'f a'.
-- | Computes a constant-time approximation to the edit distance of a diff. This is done by comparing at most _m_ nodes, & assuming the rest are zero-cost.
editDistanceUpTo :: (P.Foldable f, Functor f) => Integer
-> Free (CofreeF f (Both a)) (Patch (Cofree f a))
-> Diff f annotation
-> Int
editDistanceUpTo m = diffSum (patchSum termSize) . cutoff m
where diffSum patchCost = sum . fmap (maybe 0 patchCost)
@ -213,7 +219,7 @@ defaultM :: Integer
defaultM = 10
-- | A term which has not yet been mapped by `rws`, along with its feature vector summary & index.
data UnmappedTerm f fields = UnmappedTerm { termIndex :: {-# UNPACK #-} !Int, feature :: !(Vector.Vector Double), term :: !(Cofree f (Record fields)) }
data UnmappedTerm f fields = UnmappedTerm { termIndex :: {-# UNPACK #-} !Int, feature :: !(Vector.Vector Double), term :: !(Term f (Record fields)) }
-- | Either a `term`, an index of a matched term, or nil.
data TermOrIndexOrNil term = Term term | Index Int | Nil
@ -226,11 +232,14 @@ data Gram label = Gram { stem :: [Maybe label], base :: [Maybe label] }
deriving (Eq, Show)
-- | Annotates a term with a feature vector at each node, using the default values for the p, q, and d parameters.
defaultFeatureVectorDecorator :: (Hashable label, Traversable f) => (forall b. CofreeF f (Record fields) b -> label) -> Cofree f (Record fields) -> Cofree f (Record (Vector.Vector Double ': fields))
defaultFeatureVectorDecorator :: (Hashable label, Traversable f) =>
(forall b. TermF f (Record fields) b -> label)
-> Term f (Record fields)
-> Term f (Record (Vector.Vector Double ': fields))
defaultFeatureVectorDecorator getLabel = featureVectorDecorator getLabel defaultP defaultQ defaultD
-- | Annotates a term with a feature vector at each node, parameterized by stem length, base width, and feature vector dimensions.
featureVectorDecorator :: (Hashable label, Traversable f) => (forall b. CofreeF f (Record fields) b -> label) -> Int -> Int -> Int -> Cofree f (Record fields) -> Cofree f (Record (Vector.Vector Double ': fields))
featureVectorDecorator :: (Hashable label, Traversable f) => (forall b. TermF f (Record fields) b -> label) -> Int -> Int -> Int -> Term f (Record fields) -> Term f (Record (Vector.Vector Double ': fields))
featureVectorDecorator getLabel p q d
= cata (\ (RCons gram rest :< functor) ->
cofree ((foldr (Vector.zipWith (+) . getField . extract) (unitVector d (hash gram)) functor .: rest) :< functor))
@ -238,11 +247,11 @@ featureVectorDecorator getLabel p q d
-- | Annotates a term with the corresponding p,q-gram at each node.
pqGramDecorator :: Traversable f
=> (forall b. CofreeF f (Record fields) b -> label) -- ^ A function computing the label from an arbitrary unpacked term. This function can use the annotation and functor's constructor, but not any recursive values inside the functor (since they're held parametric in 'b').
-> Int -- ^ 'p'; the desired stem length for the grams.
-> Int -- ^ 'q'; the desired base length for the grams.
-> Cofree f (Record fields) -- ^ The term to decorate.
-> Cofree f (Record (Gram label ': fields)) -- ^ The decorated term.
=> (forall b. TermF f (Record fields) b -> label) -- ^ A function computing the label from an arbitrary unpacked term. This function can use the annotation and functor's constructor, but not any recursive values inside the functor (since they're held parametric in 'b').
-> Int -- ^ 'p'; the desired stem length for the grams.
-> Int -- ^ 'q'; the desired base length for the grams.
-> Term f (Record fields) -- ^ The term to decorate.
-> Term f (Record (Gram label ': fields)) -- ^ The decorated term.
pqGramDecorator getLabel p q = cata algebra
where
algebra term = let label = getLabel term in
@ -250,13 +259,15 @@ pqGramDecorator getLabel p q = cata algebra
gram label = Gram (padToSize p []) (padToSize q (pure (Just label)))
assignParentAndSiblingLabels functor label = (`evalState` (replicate (q `div` 2) Nothing <> siblingLabels functor)) (for functor (assignLabels label))
assignLabels :: label -> Cofree f (Record (Gram label ': fields)) -> State [Maybe label] (Cofree f (Record (Gram label ': fields)))
assignLabels :: label
-> Term f (Record (Gram label ': fields))
-> State [Maybe label] (Term f (Record (Gram label ': fields)))
assignLabels label a = case runCofree a of
RCons gram rest :< functor -> do
labels <- get
put (drop 1 labels)
pure $! cofree ((gram { stem = padToSize p (Just label : stem gram), base = padToSize q labels } .: rest) :< functor)
siblingLabels :: Traversable f => f (Cofree f (Record (Gram label ': fields))) -> [Maybe label]
siblingLabels :: Traversable f => f (Term f (Record (Gram label ': fields))) -> [Maybe label]
siblingLabels = foldMap (base . rhead . extract)
padToSize n list = take n (list <> repeat empty)
@ -268,11 +279,13 @@ unitVector d hash = normalize ((`evalRand` mkQCGen hash) (sequenceA (Vector.repl
vmagnitude = sqrtDouble . Vector.sum . fmap (** 2)
-- | Strips the head annotation off a term annotated with non-empty records.
stripTerm :: Functor f => Cofree f (Record (h ': t)) -> Cofree f (Record t)
stripTerm :: Functor f => Term f (Record (h ': t)) -> Term f (Record t)
stripTerm = fmap rtail
-- | Strips the head annotation off a diff annotated with non-empty records.
stripDiff :: (Functor f, Functor g) => Free (CofreeF f (g (Record (h ': t)))) (Patch (Cofree f (Record (h ': t)))) -> Free (CofreeF f (g (Record t))) (Patch (Cofree f (Record t)))
stripDiff :: (Functor f, Functor g)
=> Free (TermF f (g (Record (h ': t)))) (Patch (Term f (Record (h ': t))))
-> Free (TermF f (g (Record t))) (Patch (Term f (Record t)))
stripDiff = iter (\ (h :< f) -> wrap (fmap rtail h :< f)) . fmap (pure . fmap stripTerm)

View File

@ -1,7 +1,7 @@
{-# LANGUAGE TypeFamilies #-}
module Term.Arbitrary where
import Data.Functor.Foldable (Base, cata, unfold, Unfoldable(embed))
import Data.Functor.Foldable (Base, unfold, Unfoldable(embed))
import Data.Text.Arbitrary ()
import Prologue
import Syntax