1
1
mirror of https://github.com/github/semantic.git synced 2024-12-19 04:41:47 +03:00

Merge pull request #740 from github/intmap-intersections-in-rws

IntMap intersections in RWS
This commit is contained in:
Josh Vera 2016-08-19 13:16:03 -04:00 committed by GitHub
commit 08b591db7b

View File

@ -1,4 +1,4 @@
{-# LANGUAGE DataKinds, GADTs, RankNTypes, TypeOperators #-}
{-# LANGUAGE DataKinds, GADTs, RankNTypes, ScopedTypeVariables, TypeOperators #-}
module Data.RandomWalkSimilarity
( rws
, pqGramDecorator
@ -20,20 +20,19 @@ import Control.Monad.State
import Data.Functor.Both hiding (fst, snd)
import Data.Functor.Foldable as Foldable
import Data.Hashable
import qualified Data.IntMap as IntMap
import qualified Data.KdTree.Static as KdTree
import qualified Data.List as List
import Data.Record
import qualified Data.Vector as Vector
import Patch
import Prologue
import Term ()
import Term (termSize)
import Test.QuickCheck hiding (Fixed)
import Test.QuickCheck.Random
import Data.List (intersectBy)
import Term (termSize)
-- | Given a function comparing two terms recursively, and a function to compute a Hashable label from an unpacked term, compute the diff of a pair of lists of terms using a random walk similarity metric, which completes in log-linear time. This implementation is based on the paper [_RWS-Diff—Flexible and Efficient Change Detection in Hierarchical Data_](https://github.com/github/semantic-diff/files/325837/RWS-Diff.Flexible.and.Efficient.Change.Detection.in.Hierarchical.Data.pdf).
rws :: (Eq (Record fields), Prologue.Foldable f, Functor f, Eq (f (Cofree f (Record fields))), HasField fields (Vector.Vector Double))
rws :: forall f fields. (Eq (Record fields), Prologue.Foldable f, Functor f, Eq (f (Cofree f (Record fields))), HasField fields (Vector.Vector Double))
=> (Cofree f (Record fields) -> Cofree f (Record fields) -> Maybe (Free (CofreeF f (Both (Record fields))) (Patch (Cofree f (Record fields))))) -- ^ A function which compares a pair of terms recursively, returning 'Just' their diffed value if appropriate, or 'Nothing' if they should not be compared.
-> [Cofree f (Record fields)] -- ^ The list of old terms.
-> [Cofree f (Record fields)] -- ^ The list of new terms.
@ -42,22 +41,24 @@ rws compare as bs
| null as, null bs = []
| null as = inserting <$> bs
| null bs = deleting <$> as
| otherwise = fmap snd . uncurry deleteRemaining . (`runState` (negate 1, fas, fbs)) $ traverse findNearestNeighbourTo fbs
| otherwise = fmap snd . uncurry deleteRemaining . (`runState` (negate 1, toMap fas, toMap fbs)) $ traverse findNearestNeighbourTo fbs
where fas = zipWith featurize [0..] as
fbs = zipWith featurize [0..] bs
kdas = KdTree.build (Vector.toList . feature) fas
kdbs = KdTree.build (Vector.toList . feature) fbs
featurize index term = UnmappedTerm index (getField (extract term)) term
toMap = IntMap.fromList . fmap (termIndex &&& identity)
findNearestNeighbourTo :: UnmappedTerm (Cofree f (Record fields)) -> State (Int, IntMap (UnmappedTerm (Cofree f (Record fields))), IntMap (UnmappedTerm (Cofree f (Record fields)))) (Int, Free (CofreeF f (Both (Record fields))) (Patch (Cofree f (Record fields))))
findNearestNeighbourTo kv@(UnmappedTerm j _ b) = do
(previous, unmappedA, unmappedB) <- get
fromMaybe (insertion previous unmappedA unmappedB kv) $ do
foundA@(UnmappedTerm i _ a) <- nearestUnmapped unmappedA kdas kv
foundB@(UnmappedTerm j' _ _) <- nearestUnmapped unmappedB kdbs foundA
UnmappedTerm j' _ _ <- nearestUnmapped unmappedB kdbs foundA
guard (j == j')
guard (previous <= i && i <= previous + defaultMoveBound)
compared <- compare a b
pure $! do
put (i, List.delete foundA unmappedA, List.delete foundB unmappedB)
put (i, IntMap.delete i unmappedA, IntMap.delete j unmappedB)
pure (i, compared)
-- | Finds the most-similar unmapped term to the passed-in term, if any.
@ -65,10 +66,11 @@ rws compare as bs
-- RWS can produce false positives in the case of e.g. hash collisions. Therefore, we find the _l_ nearest candidates, filter out any which have already been mapped, and select the minimum of the remaining by (a constant-time approximation of) edit distance.
--
-- cf §4.2 of RWS-Diff
nearestUnmapped unmapped tree key = getFirst $ foldMap (First . Just) (sortOn (maybe maxBound (editDistanceUpTo defaultM) . compare (term key) . term) (intersectBy ((==) `on` termIndex) unmapped (KdTree.kNearest tree defaultL key)))
nearestUnmapped :: IntMap (UnmappedTerm (Cofree f (Record fields))) -> KdTree.KdTree Double (UnmappedTerm (Cofree f (Record fields))) -> UnmappedTerm (Cofree f (Record fields)) -> Maybe (UnmappedTerm (Cofree f (Record fields)))
nearestUnmapped unmapped tree key = getFirst $ foldMap (First . Just) (sortOn (maybe maxBound (editDistanceUpTo defaultM) . compare (term key) . term) (toList (IntMap.intersection unmapped (toMap (KdTree.kNearest tree defaultL key)))))
insertion previous unmappedA unmappedB kv@(UnmappedTerm _ _ b) = do
put (previous, unmappedA, List.delete kv unmappedB)
insertion previous unmappedA unmappedB (UnmappedTerm j _ b) = do
put (previous, unmappedA, IntMap.delete j unmappedB)
pure (negate 1, inserting b)
deleteRemaining diffs (_, unmappedA, _) = foldl' (flip (List.insertBy (comparing fst))) diffs ((termIndex &&& deleting . term) <$> unmappedA)