mirror of
https://github.com/github/semantic.git
synced 2024-12-20 21:31:48 +03:00
Merge pull request #740 from github/intmap-intersections-in-rws
IntMap intersections in RWS
This commit is contained in:
commit
08b591db7b
@ -1,4 +1,4 @@
|
|||||||
{-# LANGUAGE DataKinds, GADTs, RankNTypes, TypeOperators #-}
|
{-# LANGUAGE DataKinds, GADTs, RankNTypes, ScopedTypeVariables, TypeOperators #-}
|
||||||
module Data.RandomWalkSimilarity
|
module Data.RandomWalkSimilarity
|
||||||
( rws
|
( rws
|
||||||
, pqGramDecorator
|
, pqGramDecorator
|
||||||
@ -20,20 +20,19 @@ import Control.Monad.State
|
|||||||
import Data.Functor.Both hiding (fst, snd)
|
import Data.Functor.Both hiding (fst, snd)
|
||||||
import Data.Functor.Foldable as Foldable
|
import Data.Functor.Foldable as Foldable
|
||||||
import Data.Hashable
|
import Data.Hashable
|
||||||
|
import qualified Data.IntMap as IntMap
|
||||||
import qualified Data.KdTree.Static as KdTree
|
import qualified Data.KdTree.Static as KdTree
|
||||||
import qualified Data.List as List
|
import qualified Data.List as List
|
||||||
import Data.Record
|
import Data.Record
|
||||||
import qualified Data.Vector as Vector
|
import qualified Data.Vector as Vector
|
||||||
import Patch
|
import Patch
|
||||||
import Prologue
|
import Prologue
|
||||||
import Term ()
|
import Term (termSize)
|
||||||
import Test.QuickCheck hiding (Fixed)
|
import Test.QuickCheck hiding (Fixed)
|
||||||
import Test.QuickCheck.Random
|
import Test.QuickCheck.Random
|
||||||
import Data.List (intersectBy)
|
|
||||||
import Term (termSize)
|
|
||||||
|
|
||||||
-- | Given a function comparing two terms recursively, and a function to compute a Hashable label from an unpacked term, compute the diff of a pair of lists of terms using a random walk similarity metric, which completes in log-linear time. This implementation is based on the paper [_RWS-Diff—Flexible and Efficient Change Detection in Hierarchical Data_](https://github.com/github/semantic-diff/files/325837/RWS-Diff.Flexible.and.Efficient.Change.Detection.in.Hierarchical.Data.pdf).
|
-- | Given a function comparing two terms recursively, and a function to compute a Hashable label from an unpacked term, compute the diff of a pair of lists of terms using a random walk similarity metric, which completes in log-linear time. This implementation is based on the paper [_RWS-Diff—Flexible and Efficient Change Detection in Hierarchical Data_](https://github.com/github/semantic-diff/files/325837/RWS-Diff.Flexible.and.Efficient.Change.Detection.in.Hierarchical.Data.pdf).
|
||||||
rws :: (Eq (Record fields), Prologue.Foldable f, Functor f, Eq (f (Cofree f (Record fields))), HasField fields (Vector.Vector Double))
|
rws :: forall f fields. (Eq (Record fields), Prologue.Foldable f, Functor f, Eq (f (Cofree f (Record fields))), HasField fields (Vector.Vector Double))
|
||||||
=> (Cofree f (Record fields) -> Cofree f (Record fields) -> Maybe (Free (CofreeF f (Both (Record fields))) (Patch (Cofree f (Record fields))))) -- ^ A function which compares a pair of terms recursively, returning 'Just' their diffed value if appropriate, or 'Nothing' if they should not be compared.
|
=> (Cofree f (Record fields) -> Cofree f (Record fields) -> Maybe (Free (CofreeF f (Both (Record fields))) (Patch (Cofree f (Record fields))))) -- ^ A function which compares a pair of terms recursively, returning 'Just' their diffed value if appropriate, or 'Nothing' if they should not be compared.
|
||||||
-> [Cofree f (Record fields)] -- ^ The list of old terms.
|
-> [Cofree f (Record fields)] -- ^ The list of old terms.
|
||||||
-> [Cofree f (Record fields)] -- ^ The list of new terms.
|
-> [Cofree f (Record fields)] -- ^ The list of new terms.
|
||||||
@ -42,22 +41,24 @@ rws compare as bs
|
|||||||
| null as, null bs = []
|
| null as, null bs = []
|
||||||
| null as = inserting <$> bs
|
| null as = inserting <$> bs
|
||||||
| null bs = deleting <$> as
|
| null bs = deleting <$> as
|
||||||
| otherwise = fmap snd . uncurry deleteRemaining . (`runState` (negate 1, fas, fbs)) $ traverse findNearestNeighbourTo fbs
|
| otherwise = fmap snd . uncurry deleteRemaining . (`runState` (negate 1, toMap fas, toMap fbs)) $ traverse findNearestNeighbourTo fbs
|
||||||
where fas = zipWith featurize [0..] as
|
where fas = zipWith featurize [0..] as
|
||||||
fbs = zipWith featurize [0..] bs
|
fbs = zipWith featurize [0..] bs
|
||||||
kdas = KdTree.build (Vector.toList . feature) fas
|
kdas = KdTree.build (Vector.toList . feature) fas
|
||||||
kdbs = KdTree.build (Vector.toList . feature) fbs
|
kdbs = KdTree.build (Vector.toList . feature) fbs
|
||||||
featurize index term = UnmappedTerm index (getField (extract term)) term
|
featurize index term = UnmappedTerm index (getField (extract term)) term
|
||||||
|
toMap = IntMap.fromList . fmap (termIndex &&& identity)
|
||||||
|
findNearestNeighbourTo :: UnmappedTerm (Cofree f (Record fields)) -> State (Int, IntMap (UnmappedTerm (Cofree f (Record fields))), IntMap (UnmappedTerm (Cofree f (Record fields)))) (Int, Free (CofreeF f (Both (Record fields))) (Patch (Cofree f (Record fields))))
|
||||||
findNearestNeighbourTo kv@(UnmappedTerm j _ b) = do
|
findNearestNeighbourTo kv@(UnmappedTerm j _ b) = do
|
||||||
(previous, unmappedA, unmappedB) <- get
|
(previous, unmappedA, unmappedB) <- get
|
||||||
fromMaybe (insertion previous unmappedA unmappedB kv) $ do
|
fromMaybe (insertion previous unmappedA unmappedB kv) $ do
|
||||||
foundA@(UnmappedTerm i _ a) <- nearestUnmapped unmappedA kdas kv
|
foundA@(UnmappedTerm i _ a) <- nearestUnmapped unmappedA kdas kv
|
||||||
foundB@(UnmappedTerm j' _ _) <- nearestUnmapped unmappedB kdbs foundA
|
UnmappedTerm j' _ _ <- nearestUnmapped unmappedB kdbs foundA
|
||||||
guard (j == j')
|
guard (j == j')
|
||||||
guard (previous <= i && i <= previous + defaultMoveBound)
|
guard (previous <= i && i <= previous + defaultMoveBound)
|
||||||
compared <- compare a b
|
compared <- compare a b
|
||||||
pure $! do
|
pure $! do
|
||||||
put (i, List.delete foundA unmappedA, List.delete foundB unmappedB)
|
put (i, IntMap.delete i unmappedA, IntMap.delete j unmappedB)
|
||||||
pure (i, compared)
|
pure (i, compared)
|
||||||
|
|
||||||
-- | Finds the most-similar unmapped term to the passed-in term, if any.
|
-- | Finds the most-similar unmapped term to the passed-in term, if any.
|
||||||
@ -65,10 +66,11 @@ rws compare as bs
|
|||||||
-- RWS can produce false positives in the case of e.g. hash collisions. Therefore, we find the _l_ nearest candidates, filter out any which have already been mapped, and select the minimum of the remaining by (a constant-time approximation of) edit distance.
|
-- RWS can produce false positives in the case of e.g. hash collisions. Therefore, we find the _l_ nearest candidates, filter out any which have already been mapped, and select the minimum of the remaining by (a constant-time approximation of) edit distance.
|
||||||
--
|
--
|
||||||
-- cf §4.2 of RWS-Diff
|
-- cf §4.2 of RWS-Diff
|
||||||
nearestUnmapped unmapped tree key = getFirst $ foldMap (First . Just) (sortOn (maybe maxBound (editDistanceUpTo defaultM) . compare (term key) . term) (intersectBy ((==) `on` termIndex) unmapped (KdTree.kNearest tree defaultL key)))
|
nearestUnmapped :: IntMap (UnmappedTerm (Cofree f (Record fields))) -> KdTree.KdTree Double (UnmappedTerm (Cofree f (Record fields))) -> UnmappedTerm (Cofree f (Record fields)) -> Maybe (UnmappedTerm (Cofree f (Record fields)))
|
||||||
|
nearestUnmapped unmapped tree key = getFirst $ foldMap (First . Just) (sortOn (maybe maxBound (editDistanceUpTo defaultM) . compare (term key) . term) (toList (IntMap.intersection unmapped (toMap (KdTree.kNearest tree defaultL key)))))
|
||||||
|
|
||||||
insertion previous unmappedA unmappedB kv@(UnmappedTerm _ _ b) = do
|
insertion previous unmappedA unmappedB (UnmappedTerm j _ b) = do
|
||||||
put (previous, unmappedA, List.delete kv unmappedB)
|
put (previous, unmappedA, IntMap.delete j unmappedB)
|
||||||
pure (negate 1, inserting b)
|
pure (negate 1, inserting b)
|
||||||
|
|
||||||
deleteRemaining diffs (_, unmappedA, _) = foldl' (flip (List.insertBy (comparing fst))) diffs ((termIndex &&& deleting . term) <$> unmappedA)
|
deleteRemaining diffs (_, unmappedA, _) = foldl' (flip (List.insertBy (comparing fst))) diffs ((termIndex &&& deleting . term) <$> unmappedA)
|
||||||
|
Loading…
Reference in New Issue
Block a user