1
1
mirror of https://github.com/github/semantic.git synced 2025-01-09 00:56:32 +03:00
semantic/src/RWS.hs
2017-09-19 09:36:12 -07:00

351 lines
19 KiB
Haskell
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{-# LANGUAGE GADTs, DataKinds, RankNTypes, TypeOperators #-}
module RWS
( rws
, ComparabilityRelation
, FeatureVector(..)
, defaultFeatureVectorDecorator
, featureVectorDecorator
, pqGramDecorator
, Gram(..)
, defaultD
, equalTerms
) where
import Control.Applicative (empty)
import Control.Arrow ((&&&))
import Control.Monad.State.Strict
import Data.Align.Generic
import Data.Foldable
import Data.Function ((&))
import Data.Functor.Foldable
import Data.Hashable
import Data.List (sortOn)
import Data.Maybe
import Data.Record
import Data.Semigroup hiding (First(..))
import Data.These
import Data.Traversable
import Term
import Data.Array.Unboxed
import Data.Functor.Classes
import Diff (DiffF(..), deleting, inserting, merge, replacing)
import SES
import Data.KdMap.Static hiding (elems, empty)
import qualified Data.IntMap as IntMap
import Control.Monad.Random
import System.Random.Mersenne.Pure64
-- | A function computing a label from a single unpacked layer of a term. Since it is polymorphic in @b@, it can use the annotation and the functor's constructor, but not the recursive values held inside the functor.
type Label f fields label = forall b. TermF f (Record fields) b -> label
-- | A relation on single layers of 'Term's. It is polymorphic in the types of the recursive positions (@a@ and @b@), so it cannot inspect subterms and is therefore guaranteed constant-time in the size of the 'Term' by parametricity.
--
-- This is used both to determine whether two root terms can be compared in O(1), and, recursively, to determine whether two nodes are equal in O(n); thus, comparability is defined such that two terms are equal if they are recursively comparable subterm-wise.
type ComparabilityRelation syntax ann1 ann2 = forall a b. TermF syntax ann1 a -> TermF syntax ann2 b -> Bool
-- | A feature vector summarizing a term, stored as an unboxed array of 'Double's indexed by 'Int'.
newtype FeatureVector = FV { unFV :: UArray Int Double }
  deriving (Eq, Ord, Show)
-- | A term which has not yet been mapped by 'rws', bundled with its feature vector summary and its index.
data UnmappedTerm syntax ann = UnmappedTerm
  { termIndex :: {-# UNPACK #-} !Int           -- ^ The index of the term within its root term.
  , feature   :: {-# UNPACK #-} !FeatureVector -- ^ The feature vector summarizing the term.
  , term      :: Term syntax ann               -- ^ The unmapped term itself.
  }
-- | Either a @term@, the index of an already-matched term, or nothing at all.
data TermOrIndexOrNone term
  = Term term                 -- ^ A term still awaiting a match.
  | Index {-# UNPACK #-} !Int -- ^ The index of a term that has already been matched.
  | None                      -- ^ No term and no index.
-- | Compute an edit script between two lists of terms annotated with feature vectors.
--
-- If either list is empty, every term of the other list is emitted on its own ('This' for the first list, 'That' for the second). Two singleton lists are paired with 'These' when their roots are comparable, and otherwise emitted as a 'That' followed by a 'This'. In every other case the lists are first aligned with 'ses' using the supplied equivalence; the terms it leaves unpaired are then matched by nearest feature vector, and the results are merged back together in index order.
rws :: (Eq1 syntax, Foldable syntax, Functor syntax, GAlign syntax)
    => ComparabilityRelation syntax (Record (FeatureVector ': fields1)) (Record (FeatureVector ': fields2))
    -> (Term syntax (Record (FeatureVector ': fields1)) -> Term syntax (Record (FeatureVector ': fields2)) -> Bool)
    -> [Term syntax (Record (FeatureVector ': fields1))]
    -> [Term syntax (Record (FeatureVector ': fields2))]
    -> RWSEditScript syntax (Record (FeatureVector ': fields1)) (Record (FeatureVector ': fields2))
rws _ _ as [] = This <$> as
rws _ _ [] bs = That <$> bs
rws canCompare _ [a] [b]
  | canCompareTerms canCompare a b = [These a b]
  | otherwise = [That b, This a]
rws canCompare equivalent as bs = snd <$> insertMapped mappedDiffs (deleteRemaining nearestDiffs unmatchedAs)
  where
    -- Split the SES result into unpaired terms on each side, already-paired diffs, and the ordered worklist.
    (featureAs, featureBs, mappedDiffs, allDiffs) = genFeaturizedTermsAndDiffs (ses equivalent as bs)
    -- Match the unpaired terms of the second list against those of the first; whatever is left of the first list is unmatched.
    (nearestDiffs, unmatchedAs) = findNearestNeighboursToDiff canCompare allDiffs featureAs featureBs
-- | An 'IntMap.IntMap' of unmapped terms keyed by their position in a list of terms (i.e. by each term's 'termIndex').
type UnmappedTerms syntax ann = IntMap.IntMap (UnmappedTerm syntax ann)
-- | A single entry of an edit script: 'This' holds a term from the first list only, 'That' a term from the second list only, and 'These' one term from each list.
type Edit syntax ann1 ann2 = These (Term syntax ann1) (Term syntax ann2)
-- | An 'Edit' paired with the indices of the terms it mentions: 'This' for an index into the first list, 'That' for an index into the second list, and 'These' for both.
type MappedDiff syntax ann1 ann2 = (These Int Int, Edit syntax ann1 ann2)
-- | The edit script consumed and produced by 'rws': a list of 'Edit's.
type RWSEditScript syntax ann1 ann2 = [Edit syntax ann1 ann2]
-- | Insert every diff of the first collection into the second list, one at a time from left to right, using 'insertDiff' to position each one.
insertMapped :: Foldable t
             => t (MappedDiff syntax ann1 ann2)
             -> [MappedDiff syntax ann1 ann2]
             -> [MappedDiff syntax ann1 ann2]
insertMapped diffs into = foldl' (\ placed mapped -> insertDiff mapped placed) into diffs
-- | Insert each remaining unmapped term of the first list into the diffs as a 'This' entry, positioned by its 'termIndex'.
deleteRemaining :: Traversable t
                => [MappedDiff syntax ann1 ann2]
                -> t (UnmappedTerm syntax ann1)
                -> [MappedDiff syntax ann1 ann2]
deleteRemaining diffs unmappedAs = foldl' (\ placed unmapped -> insertDiff (asDeletion unmapped) placed) diffs unmappedAs
  where
    -- Pair the term's index with the term itself, both tagged as belonging to the first list only.
    asDeletion unmapped = (This (termIndex unmapped), This (term unmapped))
-- | Inserts an index and diff pair into a list of indices and diffs, choosing its position by comparing its indices against those of the entries already present.
insertDiff :: MappedDiff syntax ann1 ann2
           -> [MappedDiff syntax ann1 ann2]
           -> [MappedDiff syntax ann1 ann2]
insertDiff inserted [] = [ inserted ]
insertDiff a@(ij1, _) (b@(ij2, _) : rest) = case (ij1, ij2) of
  (These i1 i2, These j1 j2) -> beforeWhen (i1 <= j1 && i2 <= j2)
  (This i, This j) -> beforeWhen (i <= j)
  (That i, That j) -> beforeWhen (i <= j)
  (This i, These j _) -> beforeWhen (i <= j)
  (That i, These _ j) -> beforeWhen (i <= j)
  -- A one-sided entry is never placed directly in front of a one-sided entry from the other list.
  (This _, That _) -> afterHead
  (That _, This _) -> afterHead
  -- A paired entry facing a one-sided entry: split the run of one-sided entries up to the next paired entry around it.
  (These i1 i2, _) ->
    let (oneSided, fromNextPair) = break (isThese . fst) rest
        (before, after) = foldr' (splitAround i1 i2) ([], []) (b : oneSided)
    in case after of
         -- Nothing in the run belongs after the new entry, so keep looking beyond the run.
         [] -> before <> insertDiff a fromNextPair
         _ -> before <> (a : after) <> fromNextPair
  where
    beforeHead = a : b : rest
    afterHead = b : insertDiff a rest
    beforeWhen precedes = if precedes then beforeHead else afterHead
    -- Sort a one-sided entry into the entries preceding or following the indices @i1@ / @i2@.
    splitAround i1 i2 each (before, after) = case fst each of
      This j1
        | i1 <= j1 -> (before, each : after)
        | otherwise -> (each : before, after)
      That j2
        | i2 <= j2 -> (before, each : after)
        | otherwise -> (each : before, after)
      These _ _ -> (before, after)
-- | Walk the worklist in order, matching each unmapped term of the second list against the unmapped terms of the first.
--
-- Returns the diffs produced along the way together with the terms of the first list which were still unmapped at the end.
findNearestNeighboursToDiff :: (Foldable syntax, Functor syntax, GAlign syntax)
                            => ComparabilityRelation syntax ann1 ann2 -- ^ A relation determining whether two terms can be compared.
                            -> [TermOrIndexOrNone (UnmappedTerm syntax ann2)]
                            -> [UnmappedTerm syntax ann1]
                            -> [UnmappedTerm syntax ann2]
                            -> ([MappedDiff syntax ann1 ann2], UnmappedTerms syntax ann1)
findNearestNeighboursToDiff canCompare allDiffs featureAs featureBs = (diffs, remaining)
  where
    (diffs, (_, remaining, _)) = runState matchAll initialState
    -- Worklist entries which yield no diff are dropped from the result.
    matchAll = catMaybes <$> traverse matchOne allDiffs
    matchOne = findNearestNeighbourToDiff' canCompare (toKdMap featureAs) (toKdMap featureBs)
    -- The state is (index of the previous match in the first list, unmapped terms of the first list, unmapped terms of the second list).
    initialState = (minimumTermIndex featureAs, toMap featureAs, toMap featureBs)
-- | Process a single worklist entry.
--
-- 'None' produces nothing; a term is matched via 'findNearestNeighbourTo'; an 'Index' produces nothing but records that index as the position of the previous match.
findNearestNeighbourToDiff' :: (Foldable syntax, Functor syntax, GAlign syntax)
                            => ComparabilityRelation syntax ann1 ann2 -- ^ A relation determining whether two terms can be compared.
                            -> KdMap Double FeatureVector (UnmappedTerm syntax ann1)
                            -> KdMap Double FeatureVector (UnmappedTerm syntax ann2)
                            -> TermOrIndexOrNone (UnmappedTerm syntax ann2)
                            -> State (Int, UnmappedTerms syntax ann1, UnmappedTerms syntax ann2)
                                     (Maybe (MappedDiff syntax ann1 ann2))
findNearestNeighbourToDiff' _ _ _ None = pure Nothing
findNearestNeighbourToDiff' canCompare kdTreeA kdTreeB (RWS.Term unmapped) =
  Just <$> findNearestNeighbourTo canCompare kdTreeA kdTreeB unmapped
findNearestNeighbourToDiff' _ _ _ (Index i) =
  Nothing <$ modify' (\ (_, unmappedA, unmappedB) -> (i, unmappedA, unmappedB))
-- | Construct a diff for a term in B by matching it against the most similar eligible term in A (if any), marking both as ineligible for future matches.
--
-- When no acceptable match exists, the term is recorded as a 'That' entry via 'insertion'.
findNearestNeighbourTo :: (Foldable syntax, Functor syntax, GAlign syntax)
                       => ComparabilityRelation syntax ann1 ann2 -- ^ A relation determining whether two terms can be compared.
                       -> KdMap Double FeatureVector (UnmappedTerm syntax ann1)
                       -> KdMap Double FeatureVector (UnmappedTerm syntax ann2)
                       -> UnmappedTerm syntax ann2
                       -> State (Int, UnmappedTerms syntax ann1, UnmappedTerms syntax ann2)
                                (MappedDiff syntax ann1 ann2)
findNearestNeighbourTo canCompare kdTreeA kdTreeB term@(UnmappedTerm j _ b) = do
  (previous, unmappedA, unmappedB) <- get
  let matched = do
        -- Look up the nearest unmapped term in `unmappedA`.
        foundA@(UnmappedTerm i _ a) <- nearestUnmapped canCompare (withinMoveBounds previous unmappedA) kdTreeA term
        -- Look up the nearest unmapped term to `foundA` in `unmappedB`.
        UnmappedTerm j' _ _ <- nearestUnmapped (flip canCompare) (withinMoveBounds (pred j) unmappedB) kdTreeB foundA
        -- Only accept the match if it leads back to the term we started from and the two terms are comparable.
        guard (j == j' && canCompareTerms canCompare a b)
        pure $! do
          put (i, IntMap.delete i unmappedA, IntMap.delete j unmappedB)
          pure (These i j, These a b)
  fromMaybe (insertion previous unmappedA unmappedB term) matched
  where
    -- Restrict a map of unmapped terms to the keys 'isInMoveBounds' accepts relative to @bound@.
    withinMoveBounds bound = IntMap.filterWithKey (\ key _ -> isInMoveBounds bound key)
-- | Whether index @i@ lies strictly after @previous@ and strictly before @previous + 'defaultMoveBound'@.
isInMoveBounds :: Int -> Int -> Bool
isInMoveBounds previous i
  | i <= previous = False
  | otherwise = i < previous + defaultMoveBound
-- | Finds the most-similar unmapped term to the passed-in term, if any.
--
-- RWS can produce false positives in the case of e.g. hash collisions. Therefore, we find the _l_ nearest candidates, filter out any which have already been mapped, and select the minimum of the remaining by (a constant-time approximation of) edit distance.
--
-- cf §4.2 of RWS-Diff
nearestUnmapped :: (Foldable syntax, Functor syntax, GAlign syntax)
                => ComparabilityRelation syntax ann1 ann2 -- ^ A relation determining whether two terms can be compared.
                -> UnmappedTerms syntax ann1 -- ^ A set of terms eligible for matching against.
                -> KdMap Double FeatureVector (UnmappedTerm syntax ann1) -- ^ The k-d tree to look up nearest neighbours within.
                -> UnmappedTerm syntax ann2 -- ^ The term to find the nearest neighbour to.
                -> Maybe (UnmappedTerm syntax ann1) -- ^ The most similar unmapped term, if any.
nearestUnmapped canCompare unmapped tree key = listToMaybe rankedCandidates
  where
    -- The eligible neighbours, closest (by approximate edit distance) first.
    rankedCandidates = sortOn distanceFromKey (toList eligible)
    -- Keep only those neighbours which are still present in the set of eligible terms.
    eligible = IntMap.intersection unmapped (toMap neighbours)
    neighbours = snd <$> kNearest tree defaultL (feature key)
    distanceFromKey candidate = editDistanceIfComparable (flip canCompare) (term key) (term candidate)
-- | The approximate edit distance between two terms (considering at most 'defaultM' nodes) if their roots are comparable, and 'maxBound' otherwise.
editDistanceIfComparable :: (Foldable syntax, Functor syntax, GAlign syntax)
                         => ComparabilityRelation syntax ann1 ann2
                         -> Term syntax ann1
                         -> Term syntax ann2
                         -> Int
editDistanceIfComparable canCompare a b
  | canCompareTerms canCompare a b = editDistanceUpTo defaultM (These a b)
  | otherwise = maxBound
-- | The default number of dimensions of a feature vector.
defaultD :: Int
defaultD = 15

-- | The default number of nearest neighbours requested from the k-d tree when looking for a match.
defaultL :: Int
defaultL = 2

-- | The default stem length of a p,q-gram.
defaultP :: Int
defaultP = 2

-- | The default base width of a p,q-gram.
defaultQ :: Int
defaultQ = 3

-- | The default offset used as the exclusive upper bound when deciding whether an index is within move bounds.
defaultMoveBound :: Int
defaultMoveBound = 2
-- | Record an unmapped term of the second list as a 'That' entry.
--
-- Given a previous index, two sets of unmapped terms, and the unmapped term to insert, this stores the state (previous index, unchanged first set, second set with the term removed) and yields the term's index paired with the term.
insertion :: Int
          -> UnmappedTerms syntax ann1
          -> UnmappedTerms syntax ann2
          -> UnmappedTerm syntax ann2
          -> State (Int, UnmappedTerms syntax ann1, UnmappedTerms syntax ann2)
                   (MappedDiff syntax ann1 ann2)
insertion previous unmappedA unmappedB (UnmappedTerm j _ b) =
  (That j, That b) <$ put (previous, unmappedA, IntMap.delete j unmappedB)
-- | Split an edit script into the pieces the matching phase needs, numbering the terms of each side as it goes.
--
-- The result holds, in script order: the featurized 'This' terms, the featurized 'That' terms, the 'These' pairs tagged with both of their indices, and a worklist with one entry per edit ('None' for 'This', the featurized term for 'That', and the first-list index for 'These').
genFeaturizedTermsAndDiffs :: Functor syntax
                           => RWSEditScript syntax (Record (FeatureVector ': fields1)) (Record (FeatureVector ': fields2))
                           -> ( [UnmappedTerm syntax (Record (FeatureVector ': fields1))]
                              , [UnmappedTerm syntax (Record (FeatureVector ': fields2))]
                              , [MappedDiff syntax (Record (FeatureVector ': fields1)) (Record (FeatureVector ': fields2))]
                              , [TermOrIndexOrNone (UnmappedTerm syntax (Record (FeatureVector ': fields2)))]
                              )
genFeaturizedTermsAndDiffs sesDiffs = (reverse as, reverse bs, reverse mappedDiffs, reverse allDiffs)
  where
    -- Everything is accumulated in reverse and flipped back into script order above.
    Mapping _ _ as bs mappedDiffs allDiffs = foldl' step (Mapping 0 0 [] [] [] []) sesDiffs
    step (Mapping counterA counterB accAs accBs accMapped accAll) edit = case edit of
      This a ->
        Mapping (succ counterA) counterB (featurize counterA a : accAs) accBs accMapped (None : accAll)
      That b ->
        let featurized = featurize counterB b
        in Mapping counterA (succ counterB) accAs (featurized : accBs) accMapped (RWS.Term featurized : accAll)
      These a b ->
        Mapping (succ counterA) (succ counterB) accAs accBs ((These counterA counterB, These a b) : accMapped) (Index counterA : accAll)
-- | The strict accumulator used by 'genFeaturizedTermsAndDiffs' while folding over an edit script.
data Mapping syntax ann1 ann2
  = Mapping
    {-# UNPACK #-} !Int                               -- The next index to assign on the first side.
    {-# UNPACK #-} !Int                               -- The next index to assign on the second side.
    ![UnmappedTerm syntax ann1]                       -- Featurized terms found only on the first side.
    ![UnmappedTerm syntax ann2]                       -- Featurized terms found only on the second side.
    ![MappedDiff syntax ann1 ann2]                    -- Pairs already matched, tagged with their indices.
    ![TermOrIndexOrNone (UnmappedTerm syntax ann2)]   -- One worklist entry per edit seen so far.
-- | Package a term as an 'UnmappedTerm' at the given index, taking the feature vector from the term's root annotation and storing the term with that root feature vector erased.
featurize :: Functor syntax => Int -> Term syntax (Record (FeatureVector ': fields)) -> UnmappedTerm syntax (Record (FeatureVector ': fields))
featurize index decorated = UnmappedTerm
  { termIndex = index
  , feature = getField (extract decorated)
  , term = eraseFeatureVector decorated
  }
-- | Replace the feature vector in a term's root annotation with 'nullFeatureVector', leaving the rest of the term as it was.
eraseFeatureVector :: Functor syntax => Term syntax (Record (FeatureVector ': fields)) -> Term syntax (Record (FeatureVector ': fields))
eraseFeatureVector (Term.Term (In annotation syntax)) = termIn erased syntax
  where erased = setFeatureVector annotation nullFeatureVector
-- | A placeholder feature vector: a single-element array holding zero.
nullFeatureVector :: FeatureVector
nullFeatureVector = FV (listArray (0, 0) [0])
-- | Replace the 'FeatureVector' field of a record with the given one.
setFeatureVector :: Record (FeatureVector ': fields) -> FeatureVector -> Record (FeatureVector ': fields)
setFeatureVector record vector = setField record vector
-- | One less than the smallest 'termIndex' in the list; for an empty list the smallest index is taken to be 0, giving -1.
minimumTermIndex :: [UnmappedTerm syntax ann] -> Int
minimumTermIndex terms = pred (maybe 0 getMin smallest)
  where smallest = getOption (foldMap (Option . Just . Min . termIndex) terms)
-- | Index a list of unmapped terms by their 'termIndex'.
toMap :: [UnmappedTerm syntax ann] -> IntMap.IntMap (UnmappedTerm syntax ann)
toMap terms = IntMap.fromList [ (termIndex unmapped, unmapped) | unmapped <- terms ]
-- | Build a k-d map from a list of unmapped terms, keyed by each term's feature vector (whose coordinates are the elements of the underlying array).
toKdMap :: [UnmappedTerm syntax ann] -> KdMap Double FeatureVector (UnmappedTerm syntax ann)
toKdMap terms = build (elems . unFV) [ (feature unmapped, unmapped) | unmapped <- terms ]
-- | A 'Gram' is a fixed-size view of some portion of a tree, consisting of a 'stem' of _p_ labels for parent nodes, and a 'base' of _q_ labels of sibling nodes. Collectively, the bag of 'Gram's for each node of a tree (e.g. as computed by 'pqGramDecorator') forms a summary of the tree.
data Gram label = Gram
  { stem :: [Maybe label] -- ^ The labels of parent nodes.
  , base :: [Maybe label] -- ^ The labels of sibling nodes.
  }
  deriving (Eq, Show)
-- | Annotates a term with a feature vector at each node, using the default values for the p, q, and d parameters ('defaultP', 'defaultQ', and 'defaultD').
defaultFeatureVectorDecorator
  :: (Hashable label, Traversable f)
  => Label f fields label
  -> Term f (Record fields)
  -> Term f (Record (FeatureVector ': fields))
defaultFeatureVectorDecorator getLabel term = featureVectorDecorator getLabel defaultP defaultQ defaultD term
-- | Annotates a term with a feature vector at each node, parameterized by stem length, base width, and feature vector dimensions.
--
-- Each node's vector is the unit vector derived from the hash of its p,q-gram, summed with the vectors of its immediate subterms.
featureVectorDecorator :: (Hashable label, Traversable f) => Label f fields label -> Int -> Int -> Int -> Term f (Record fields) -> Term f (Record (FeatureVector ': fields))
featureVectorDecorator getLabel p q d term = cata collect (pqGramDecorator getLabel p q term)
  where
    -- Swap the node's gram for its feature vector, folding in the already-computed vectors of its subterms.
    collect (In (gram :. rest) functor) = termIn (foldl' addSubtermVector (unitVector d (hash gram)) functor :. rest) functor
    addSubtermVector :: Functor f => FeatureVector -> Term f (Record (FeatureVector ': fields)) -> FeatureVector
    addSubtermVector accumulated subterm = addVectors accumulated (rhead (extract subterm))
    -- Component-wise sum over the indices 0 to d - 1.
    addVectors :: FeatureVector -> FeatureVector -> FeatureVector
    addVectors (FV as) (FV bs) = FV (listArray (0, d - 1) [ as ! i + bs ! i | i <- [0 .. d - 1] ])
-- | Annotates a term with the corresponding p,q-gram at each node.
pqGramDecorator
  :: Traversable f
  => Label f fields label -- ^ A function computing the label from an arbitrary unpacked term. This function can use the annotation and functor's constructor, but not any recursive values inside the functor (since they're held parametric in 'b').
  -> Int -- ^ 'p'; the desired stem length for the grams.
  -> Int -- ^ 'q'; the desired base length for the grams.
  -> Term f (Record fields) -- ^ The term to decorate.
  -> Term f (Record (Gram label ': fields)) -- ^ The decorated term.
pqGramDecorator getLabel p q = cata algebra
  where
    -- Decorate one node with a fresh gram, then push this node's label (and the sibling labels) down into its already-decorated children.
    algebra term = termIn (gram label :. termAnnotation term) (assignParentAndSiblingLabels (termOut term) label)
      where label = getLabel term
    -- A fresh gram: an all-'Nothing' stem, and a base starting with the node's own label.
    gram label = Gram { stem = padToSize p [], base = padToSize q [Just label] }
    -- Visit the children left to right, threading the list of sibling labels still to be handed out.
    assignParentAndSiblingLabels functor label = evalState (for functor (assignLabels label)) initialLabels
      where initialLabels = replicate (q `div` 2) Nothing <> siblingLabels functor

    assignLabels :: Functor f
                 => label
                 -> Term f (Record (Gram label ': fields))
                 -> State [Maybe label] (Term f (Record (Gram label ': fields)))
    assignLabels label (Term.Term (In (gram :. rest) functor)) = do
      remaining <- get
      -- Slide the window of sibling labels along by one for the next child.
      put (drop 1 remaining)
      pure $! termIn (gram { stem = padToSize p (Just label : stem gram), base = padToSize q remaining } :. rest) functor
    siblingLabels :: Traversable f => f (Term f (Record (Gram label ': fields))) -> [Maybe label]
    siblingLabels = foldMap (base . rhead . extract)
    -- Truncate to @n@ elements, padding with 'empty' when the list is too short.
    padToSize n list = take n (list <> repeat empty)
-- | Computes a unit vector of the specified dimension from a hash, by seeding a random generator with the hash, drawing one component per dimension, and scaling the components by the inverse of their magnitude.
unitVector :: Int -> Int -> FeatureVector
unitVector d seed = FV (listArray (0, d - 1) [ component * invMagnitude | component <- components ])
  where
    components = evalRand (sequenceA (replicate d (liftRand randomDouble))) (pureMT (fromIntegral seed))
    invMagnitude = 1 / sqrt (sum [ component ** 2 | component <- components ])
-- | Test the comparability of two root 'Term's in O(1), by applying the relation to their outermost layers.
canCompareTerms :: ComparabilityRelation syntax ann1 ann2 -> Term syntax ann1 -> Term syntax ann2 -> Bool
canCompareTerms canCompare left right = canCompare leftLayer rightLayer
  where
    leftLayer = unTerm left
    rightLayer = unTerm right
-- | Recursively test the equality of two 'Term's in O(n): their roots must be comparable, and their syntax must be equal when subterms are compared the same way.
equalTerms :: Eq1 syntax => ComparabilityRelation syntax ann1 ann2 -> Term syntax ann1 -> Term syntax ann2 -> Bool
equalTerms canCompare a b
  | canCompareTerms canCompare a b = liftEq (equalTerms canCompare) (termOut (unTerm a)) (termOut (unTerm b))
  | otherwise = False
-- | How many nodes to consider for our constant-time approximation to tree edit distance. 'editDistanceIfComparable' passes this to 'editDistanceUpTo' as its cutoff.
defaultM :: Integer
defaultM = 10
-- | Computes a constant-time approximation to the edit distance of an edit, given a cutoff _m_.
--
-- For a 'This' or 'That', the distance is the size of the term. For a 'These', the two terms are approximately diffed and the cost of that diff is computed by visiting at most _m_ levels of nodes and assuming the rest are zero-cost.
editDistanceUpTo :: (GAlign syntax, Foldable syntax, Functor syntax) => Integer -> Edit syntax ann1 ann2 -> Int
editDistanceUpTo m = these termSize termSize (\ a b -> diffCost (approximateDiff a b) m)
  where
    -- Fold the diff into a function from the remaining budget to a cost.
    diffCost = cata $ \ diff budget -> case diff of
      -- Out of budget: everything from here down is treated as free.
      _ | budget <= 0 -> 0
      -- A merge costs only what its children cost.
      Merge body -> sum (fmap ($ pred budget) body)
      -- Any other node costs one more than its children.
      body -> succ (sum (fmap ($ pred budget) body))
    -- Align the two terms' children where their syntax lines up; otherwise replace one term with the other.
    approximateDiff a b = maybe (replacing a b) (merge (extract a, extract b)) aligned
      where aligned = galignWith (these deleting inserting approximateDiff) (unwrap a) (unwrap b)
-- Instances

-- | A 'Gram' hashes as the concatenation of its 'stem' and 'base'.
instance Hashable label => Hashable (Gram label) where
  -- Thread the caller's salt through rather than discarding it: otherwise any structure which chains salts through its elements (e.g. a list of 'Gram's) loses the contribution of everything hashed before a 'Gram'.
  hashWithSalt salt gram = hashWithSalt salt (stem gram <> base gram)
  -- Left exactly as before, so feature vectors derived from @hash gram@ are unchanged.
  hash gram = hash (stem gram <> base gram)