mirror of
https://github.com/github/semantic.git
synced 2024-12-19 04:41:47 +03:00
Merge pull request #740 from github/intmap-intersections-in-rws
IntMap intersections in RWS
This commit is contained in:
commit
08b591db7b
@ -1,4 +1,4 @@
|
||||
{-# LANGUAGE DataKinds, GADTs, RankNTypes, TypeOperators #-}
|
||||
{-# LANGUAGE DataKinds, GADTs, RankNTypes, ScopedTypeVariables, TypeOperators #-}
|
||||
module Data.RandomWalkSimilarity
|
||||
( rws
|
||||
, pqGramDecorator
|
||||
@ -20,20 +20,19 @@ import Control.Monad.State
|
||||
import Data.Functor.Both hiding (fst, snd)
|
||||
import Data.Functor.Foldable as Foldable
|
||||
import Data.Hashable
|
||||
import qualified Data.IntMap as IntMap
|
||||
import qualified Data.KdTree.Static as KdTree
|
||||
import qualified Data.List as List
|
||||
import Data.Record
|
||||
import qualified Data.Vector as Vector
|
||||
import Patch
|
||||
import Prologue
|
||||
import Term ()
|
||||
import Term (termSize)
|
||||
import Test.QuickCheck hiding (Fixed)
|
||||
import Test.QuickCheck.Random
|
||||
import Data.List (intersectBy)
|
||||
import Term (termSize)
|
||||
|
||||
-- | Given a function comparing two terms recursively, and a function to compute a Hashable label from an unpacked term, compute the diff of a pair of lists of terms using a random walk similarity metric, which completes in log-linear time. This implementation is based on the paper [_RWS-Diff—Flexible and Efficient Change Detection in Hierarchical Data_](https://github.com/github/semantic-diff/files/325837/RWS-Diff.Flexible.and.Efficient.Change.Detection.in.Hierarchical.Data.pdf).
|
||||
rws :: (Eq (Record fields), Prologue.Foldable f, Functor f, Eq (f (Cofree f (Record fields))), HasField fields (Vector.Vector Double))
|
||||
rws :: forall f fields. (Eq (Record fields), Prologue.Foldable f, Functor f, Eq (f (Cofree f (Record fields))), HasField fields (Vector.Vector Double))
|
||||
=> (Cofree f (Record fields) -> Cofree f (Record fields) -> Maybe (Free (CofreeF f (Both (Record fields))) (Patch (Cofree f (Record fields))))) -- ^ A function which compares a pair of terms recursively, returning 'Just' their diffed value if appropriate, or 'Nothing' if they should not be compared.
|
||||
-> [Cofree f (Record fields)] -- ^ The list of old terms.
|
||||
-> [Cofree f (Record fields)] -- ^ The list of new terms.
|
||||
@ -42,22 +41,24 @@ rws compare as bs
|
||||
| null as, null bs = []
|
||||
| null as = inserting <$> bs
|
||||
| null bs = deleting <$> as
|
||||
| otherwise = fmap snd . uncurry deleteRemaining . (`runState` (negate 1, fas, fbs)) $ traverse findNearestNeighbourTo fbs
|
||||
| otherwise = fmap snd . uncurry deleteRemaining . (`runState` (negate 1, toMap fas, toMap fbs)) $ traverse findNearestNeighbourTo fbs
|
||||
where fas = zipWith featurize [0..] as
|
||||
fbs = zipWith featurize [0..] bs
|
||||
kdas = KdTree.build (Vector.toList . feature) fas
|
||||
kdbs = KdTree.build (Vector.toList . feature) fbs
|
||||
featurize index term = UnmappedTerm index (getField (extract term)) term
|
||||
toMap = IntMap.fromList . fmap (termIndex &&& identity)
|
||||
findNearestNeighbourTo :: UnmappedTerm (Cofree f (Record fields)) -> State (Int, IntMap (UnmappedTerm (Cofree f (Record fields))), IntMap (UnmappedTerm (Cofree f (Record fields)))) (Int, Free (CofreeF f (Both (Record fields))) (Patch (Cofree f (Record fields))))
|
||||
findNearestNeighbourTo kv@(UnmappedTerm j _ b) = do
|
||||
(previous, unmappedA, unmappedB) <- get
|
||||
fromMaybe (insertion previous unmappedA unmappedB kv) $ do
|
||||
foundA@(UnmappedTerm i _ a) <- nearestUnmapped unmappedA kdas kv
|
||||
foundB@(UnmappedTerm j' _ _) <- nearestUnmapped unmappedB kdbs foundA
|
||||
UnmappedTerm j' _ _ <- nearestUnmapped unmappedB kdbs foundA
|
||||
guard (j == j')
|
||||
guard (previous <= i && i <= previous + defaultMoveBound)
|
||||
compared <- compare a b
|
||||
pure $! do
|
||||
put (i, List.delete foundA unmappedA, List.delete foundB unmappedB)
|
||||
put (i, IntMap.delete i unmappedA, IntMap.delete j unmappedB)
|
||||
pure (i, compared)
|
||||
|
||||
-- | Finds the most-similar unmapped term to the passed-in term, if any.
|
||||
@ -65,10 +66,11 @@ rws compare as bs
|
||||
-- RWS can produce false positives in the case of e.g. hash collisions. Therefore, we find the _l_ nearest candidates, filter out any which have already been mapped, and select the minimum of the remaining by (a constant-time approximation of) edit distance.
|
||||
--
|
||||
-- cf §4.2 of RWS-Diff
|
||||
nearestUnmapped unmapped tree key = getFirst $ foldMap (First . Just) (sortOn (maybe maxBound (editDistanceUpTo defaultM) . compare (term key) . term) (intersectBy ((==) `on` termIndex) unmapped (KdTree.kNearest tree defaultL key)))
|
||||
nearestUnmapped :: IntMap (UnmappedTerm (Cofree f (Record fields))) -> KdTree.KdTree Double (UnmappedTerm (Cofree f (Record fields))) -> UnmappedTerm (Cofree f (Record fields)) -> Maybe (UnmappedTerm (Cofree f (Record fields)))
|
||||
nearestUnmapped unmapped tree key = getFirst $ foldMap (First . Just) (sortOn (maybe maxBound (editDistanceUpTo defaultM) . compare (term key) . term) (toList (IntMap.intersection unmapped (toMap (KdTree.kNearest tree defaultL key)))))
|
||||
|
||||
insertion previous unmappedA unmappedB kv@(UnmappedTerm _ _ b) = do
|
||||
put (previous, unmappedA, List.delete kv unmappedB)
|
||||
insertion previous unmappedA unmappedB (UnmappedTerm j _ b) = do
|
||||
put (previous, unmappedA, IntMap.delete j unmappedB)
|
||||
pure (negate 1, inserting b)
|
||||
|
||||
deleteRemaining diffs (_, unmappedA, _) = foldl' (flip (List.insertBy (comparing fst))) diffs ((termIndex &&& deleting . term) <$> unmappedA)
|
||||
|
Loading…
Reference in New Issue
Block a user