2017-04-24 22:53:11 +03:00
{-# LANGUAGE GADTs, DataKinds, RankNTypes, TypeOperators #-}
module RWS (
rws
2017-05-31 21:20:01 +03:00
, ComparabilityRelation
2017-04-24 22:53:11 +03:00
, FeatureVector
, stripDiff
, defaultFeatureVectorDecorator
, stripTerm
, featureVectorDecorator
, pqGramDecorator
, Gram ( .. )
, defaultD
) where
2017-04-07 21:44:37 +03:00
import Prologue
import Data.Record
import Data.These
2017-04-24 22:53:11 +03:00
import Patch
2017-04-07 21:44:37 +03:00
import Term
import Data.Array
import Data.Functor.Classes
import SES
2017-04-08 00:48:14 +03:00
import qualified Data.Functor.Both as Both
2017-04-24 22:53:11 +03:00
import Data.Functor.Listable
2017-04-08 00:48:14 +03:00
import Data.KdTree.Static hiding ( toList )
2017-04-07 23:08:49 +03:00
import qualified Data.IntMap as IntMap
import Data.Semigroup ( Min ( .. ) , Option ( .. ) )
2017-04-24 22:53:11 +03:00
import Control.Monad.Random
import System.Random.Mersenne.Pure64
import Diff ( mapAnnotations )
-- | A function computing a label from a single unpacked layer of a term.
--
--   The recursive positions are held parametric in @b@, so the label can depend on the annotation and the functor’s constructor, but not on the values in recursive positions.
type Label f fields label
  = forall b. TermF f (Record fields) b -> label
2017-06-01 18:50:21 +03:00
-- | A relation on 'Term's, guaranteed constant-time in the size of the 'Term' by parametricity.
--
--   It serves two purposes: deciding in O(1) whether two root terms can be compared, and, applied recursively, deciding in O(n) whether two nodes are equal. Comparability is therefore defined such that two terms are equal if they are recursively comparable subterm-wise.
type ComparabilityRelation f fields
  = forall a b. TermF f (Record fields) a -> TermF f (Record fields) b -> Bool
2017-04-24 22:53:11 +03:00
-- | An 'Int'-indexed array of 'Double's used as the feature vector summary of a term.
type FeatureVector = Array Int Double
2017-04-07 22:42:32 +03:00
-- | A term which has not yet been mapped by 'rws', bundled with its feature vector summary and its index.
data UnmappedTerm f fields = UnmappedTerm
  { termIndex :: Int -- ^ The index of the term within its root term.
  , feature :: FeatureVector -- ^ The feature vector summarizing the term.
  , term :: Term f (Record fields) -- ^ The unmapped term itself.
  }
-- | One of: a @term@ still to be matched, the index of an already-matched term, or nothing at all.
data TermOrIndexOrNone term
  = Term term
  | Index Int
  | None
2017-06-13 22:11:52 +03:00
-- | Compute an edit script between two lists of terms.
--
--   * If either list is empty, every term of the other list is emitted as a 'This' (left only) or 'That' (right only).
--   * Two singleton lists are paired with 'These' when their roots are comparable; otherwise the right term is emitted as 'That' followed by the left term as 'This'.
--   * Otherwise, terms that are equal under 'equalTerms' are first matched with 'ses'; the remaining terms are matched by nearest-neighbour lookup on their feature vectors, any left-hand terms still unmatched are added as 'This', and the SES-matched pairs are merged back in by index.
rws :: (HasField fields (Maybe FeatureVector), Functor f, Eq1 f)
    => (Diff f fields -> Int) -- ^ A function computing a constant-time approximation to the edit distance between two terms.
    -> ComparabilityRelation f fields -- ^ A relation determining whether two terms can be compared.
    -> [Term f (Record fields)] -- ^ The left-hand terms.
    -> [Term f (Record fields)] -- ^ The right-hand terms.
    -> RWSEditScript f fields -- ^ The resulting edit script.
rws _ _ as [] = This <$> as
rws _ _ [] bs = That <$> bs
rws _ canCompare [a] [b]
  | canCompareTerms canCompare a b = [These a b]
  | otherwise = [That b, This a]
rws editDistance canCompare as bs = snd <$> rwsDiffs
  where
    -- Terms which are recursively equal are matched up front.
    sesDiffs = ses (equalTerms canCompare) as bs
    -- Number the terms on each side and featurize the ones SES left unmatched.
    (featureAs, featureBs, mappedDiffs, allDiffs) = evalState (genFeaturizedTermsAndDiffs sesDiffs) (0, 0)
    -- Match the unmatched right-hand terms against their nearest unmatched left-hand neighbours.
    (diffs, remaining) = findNearestNeighboursToDiff editDistance canCompare allDiffs featureAs featureBs
    -- Left-hand terms nobody claimed are added as 'This', then the SES matches are merged back in.
    rwsDiffs = insertMapped mappedDiffs (deleteRemaining diffs remaining)
2017-04-07 21:44:37 +03:00
2017-04-08 00:48:14 +03:00
-- | Unmapped terms stored in an 'IntMap', keyed by their position in a list of terms.
type UnmappedTerms f fields = IntMap (UnmappedTerm f fields)
2017-04-13 19:45:18 +03:00
-- | A single entry of an edit script: a left-hand term only ('This'), a right-hand term only ('That'), or a pairing of one term from each side ('These').
type Diff f fields
  = These (Term f (Record fields)) (Term f (Record fields))
2017-04-14 21:43:48 +03:00
-- | A 'Diff' paired with the index (or both indices) of the term(s) it refers to.
type MappedDiff f fields = (These Int Int, Diff f fields)
-- | The edit script produced by 'rws': a list of 'Diff's.
type RWSEditScript f fields = [Diff f fields]
2017-04-07 21:44:37 +03:00
2017-04-13 19:45:18 +03:00
-- | Insert every diff of a collection, left to right, into a list of mapped diffs using 'insertDiff'.
insertMapped :: Foldable t => t (MappedDiff f fields) -> [MappedDiff f fields] -> [MappedDiff f fields]
insertMapped diffs into = foldl' (\ accumulated diff -> insertDiff diff accumulated) into diffs
2017-04-12 00:10:08 +03:00
-- | Add each still-unmapped term to a list of mapped diffs as a 'This' entry (indexed by its 'termIndex'), inserting with 'insertDiff'.
deleteRemaining :: (Traversable t)
                => [MappedDiff f fields]
                -> t (UnmappedTerm f fields)
                -> [MappedDiff f fields]
deleteRemaining diffs unmappedAs = foldl' (flip insertDiff) diffs (toDeletion <$> unmappedAs)
  where toDeletion unmapped = (This (termIndex unmapped), This (term unmapped))
2017-04-08 00:48:14 +03:00
-- | Inserts an index and diff pair into a list of indices and diffs.
--
--   The new entry is compared against the head of the list on whichever side(s) both entries carry an index. It is placed in front of the head when its index is less than or equal to the head’s; otherwise insertion continues into the tail. Entries from opposite sides ('This' vs. 'That') are never compared directly: the new entry simply moves past.
--
--   A 'These' inserted before a one-sided head is handled by taking the run of one-sided entries up to the next 'These', splitting it into the entries which sort before and after the new one, and splicing the new entry in between (or continuing past the run if nothing sorts after it).
insertDiff :: MappedDiff f fields
           -> [MappedDiff f fields]
           -> [MappedDiff f fields]
insertDiff inserted [] = [inserted]
insertDiff a@(ij1, _) (b@(ij2, _) : rest) = case (ij1, ij2) of
  (These i1 i2, These j1 j2) -> insertHereIf (i1 <= j1 && i2 <= j2)
  (This i, This j) -> insertHereIf (i <= j)
  (That i, That j) -> insertHereIf (i <= j)
  (This i, These j _) -> insertHereIf (i <= j)
  (That i, These _ j) -> insertHereIf (i <= j)
  (This _, That _) -> insertLater
  (That _, This _) -> insertLater
  (These i1 i2, _) -> case break (isThese . fst) rest of
    (unpaired, fromNextThese) -> case foldr' (combine i1 i2) ([], []) (b : unpaired) of
      (before, []) -> before <> insertDiff a fromNextThese
      (before, after) -> before <> (a : after) <> fromNextThese
  where
    -- Place the new entry in front of the current head when the ordering test holds.
    insertHereIf isBeforeHead
      | isBeforeHead = a : b : rest
      | otherwise = insertLater
    -- Keep the current head and carry on inserting into the remainder.
    insertLater = b : insertDiff a rest
    -- Sort a one-sided entry into the entries before or after the inserted 'These'.
    combine i1 i2 each (before, after) = case fst each of
      This j1
        | i1 <= j1 -> (before, each : after)
        | otherwise -> (each : before, after)
      That j2
        | i2 <= j2 -> (before, each : after)
        | otherwise -> (each : before, after)
      These _ _ -> (before, after)
-- | Walk the list of unmatched terms and matched indices in order, matching each unmatched right-hand term against its nearest eligible left-hand neighbour.
--
--   Returns the resulting mapped diffs together with the left-hand terms which were never matched.
findNearestNeighboursToDiff :: (These (Term f (Record fields)) (Term f (Record fields)) -> Int) -- ^ A function computes a constant-time approximation to the edit distance between two terms.
                            -> ComparabilityRelation f fields -- ^ A relation determining whether two terms can be compared.
                            -> [TermOrIndexOrNone (UnmappedTerm f fields)]
                            -> [UnmappedTerm f fields]
                            -> [UnmappedTerm f fields]
                            -> ([(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))], UnmappedTerms f fields)
findNearestNeighboursToDiff editDistance canCompare allDiffs featureAs featureBs = (diffs, remaining)
  where
    -- One k-d tree per side, built from the feature vectors.
    kdTrees = toKdTree <$> Both.both featureAs featureBs
    -- State: (index of the previously matched left-hand term, unmapped left-hand terms, unmapped right-hand terms).
    initialState = (minimumTermIndex featureAs, toMap featureAs, toMap featureBs)
    matchAll = catMaybes <$> traverse (findNearestNeighbourToDiff' editDistance canCompare kdTrees) allDiffs
    (diffs, (_, remaining, _)) = runState matchAll initialState
2017-04-08 00:48:14 +03:00
2017-04-13 19:45:18 +03:00
-- | Process one entry of the featurized edit script.
--
--   'None' produces nothing; a 'Term' is matched via 'findNearestNeighbourTo'; an 'Index' records that index as the previously matched left-hand position in the state and produces nothing.
findNearestNeighbourToDiff' :: (Diff f fields -> Int) -- ^ A function computes a constant-time approximation to the edit distance between two terms.
                            -> ComparabilityRelation f fields -- ^ A relation determining whether two terms can be compared.
                            -> Both.Both (KdTree Double (UnmappedTerm f fields))
                            -> TermOrIndexOrNone (UnmappedTerm f fields)
                            -> State (Int, UnmappedTerms f fields, UnmappedTerms f fields)
                                     (Maybe (MappedDiff f fields))
findNearestNeighbourToDiff' _ _ _ None = pure Nothing
findNearestNeighbourToDiff' editDistance canCompare kdTrees (Term unmapped) =
  Just <$> findNearestNeighbourTo editDistance canCompare kdTrees unmapped
findNearestNeighbourToDiff' _ _ _ (Index i) = do
  (_, unmappedA, unmappedB) <- get
  put (i, unmappedA, unmappedB)
  pure Nothing
-- | Construct a diff for a term in B by matching it against the most similar eligible term in A (if any), marking both as ineligible for future matches.
--
--   When no acceptable match exists, the term is emitted as an insertion via 'insertion'.
findNearestNeighbourTo :: (Diff f fields -> Int) -- ^ A function computes a constant-time approximation to the edit distance between two terms.
                       -> ComparabilityRelation f fields -- ^ A relation determining whether two terms can be compared.
                       -> Both.Both (KdTree Double (UnmappedTerm f fields))
                       -> UnmappedTerm f fields
                       -> State (Int, UnmappedTerms f fields, UnmappedTerms f fields)
                                (MappedDiff f fields)
findNearestNeighbourTo editDistance canCompare kdTrees term@(UnmappedTerm j _ b) = do
  (previous, unmappedA, unmappedB) <- get
  let withinMoveBoundsFrom bound = IntMap.filterWithKey (\ key _ -> isInMoveBounds bound key)
      match = do
        -- Look up the nearest unmapped term in `unmappedA`.
        foundA@(UnmappedTerm i _ a) <- nearestUnmapped editDistance canCompare (withinMoveBoundsFrom previous unmappedA) (Both.fst kdTrees) term
        -- Look up the nearest term to `foundA` in `unmappedB`.
        UnmappedTerm j' _ _ <- nearestUnmapped editDistance canCompare (withinMoveBoundsFrom (pred j) unmappedB) (Both.snd kdTrees) foundA
        -- Only accept the match if it is mutual and the two roots are comparable.
        guard (j == j' && canCompareTerms canCompare a b)
        pure (i, a)
  case match of
    Nothing -> insertion previous unmappedA unmappedB term
    Just (i, a) -> do
      put (i, IntMap.delete i unmappedA, IntMap.delete j unmappedB)
      pure (These i j, These a b)
2017-04-08 00:48:14 +03:00
2017-04-12 19:14:36 +03:00
-- | Whether index @i@ lies strictly after @previous@ and strictly before @previous + defaultMoveBound@.
isInMoveBounds :: Int -> Int -> Bool
isInMoveBounds previous i = i > previous && i < upperBound
  where upperBound = previous + defaultMoveBound
-- | Finds the most-similar unmapped term to the passed-in term, if any.
--
--   RWS can produce false positives in the case of e.g. hash collisions. Therefore, we find the _l_ nearest candidates, filter out any which have already been mapped, and select the minimum of the remaining by (a constant-time approximation of) edit distance.
--
--   cf §4.2 of RWS-Diff
nearestUnmapped
  :: (Diff f fields -> Int) -- ^ A function computes a constant-time approximation to the edit distance between two terms.
  -> ComparabilityRelation f fields -- ^ A relation determining whether two terms can be compared.
  -> UnmappedTerms f fields -- ^ A set of terms eligible for matching against.
  -> KdTree Double (UnmappedTerm f fields) -- ^ The k-d tree to look up nearest neighbours within.
  -> UnmappedTerm f fields -- ^ The term to find the nearest neighbour to.
  -> Maybe (UnmappedTerm f fields) -- ^ The most similar unmapped term, if any.
nearestUnmapped editDistance canCompare unmapped tree key = getFirst (foldMap (First . Just) ranked)
  where
    -- The `defaultL` nearest neighbours of the key, keyed by term index.
    nearest = toMap (kNearest tree defaultL key)
    -- Only neighbours which are still unmapped are eligible.
    eligible = toList (IntMap.intersection unmapped nearest)
    -- Closest first; incomparable candidates sort last.
    ranked = sortOn (editDistanceIfComparable editDistance canCompare (term key) . term) eligible
2017-05-31 21:20:01 +03:00
-- | The approximate edit distance between two terms when their roots are comparable, and 'maxBound' otherwise.
editDistanceIfComparable :: Bounded t => (These (Term f (Record fields)) (Term f (Record fields)) -> t) -> ComparabilityRelation f fields -> Term f (Record fields) -> Term f (Record fields) -> t
editDistanceIfComparable editDistance canCompare a b
  | canCompareTerms canCompare a b = editDistance (These a b)
  | otherwise = maxBound
2017-04-24 22:53:11 +03:00
-- | The default number of dimensions of a feature vector.
defaultD :: Int
defaultD = 15

-- | The default number of nearest candidates considered when looking for a match.
defaultL :: Int
defaultL = 2

-- | The default stem length of a p,q-gram.
defaultP :: Int
defaultP = 2

-- | The default base width of a p,q-gram.
defaultQ :: Int
defaultQ = 3

-- | The default (exclusive) distance past the previously matched index within which a candidate is still eligible.
defaultMoveBound :: Int
defaultMoveBound = 2
2017-04-24 22:53:11 +03:00
2017-04-08 00:48:14 +03:00
-- | Emit an unmapped right-hand term as an insertion.
--
--   Given a previous index, the two sets of unmapped terms, and the unmapped term to insert, this sets the state to (previous index, unchanged left-hand unmapped terms, right-hand unmapped terms without the inserted term) and yields the term’s index and the term itself, both wrapped in 'That'.
insertion :: Int
          -> UnmappedTerms f fields
          -> UnmappedTerms f fields
          -> UnmappedTerm f fields
          -> State (Int, UnmappedTerms f fields, UnmappedTerms f fields)
                   (MappedDiff f fields)
insertion previous unmappedA unmappedB (UnmappedTerm j _ b) =
  put (previous, unmappedA, IntMap.delete j unmappedB) >> pure (That j, That b)
2017-04-07 22:42:32 +03:00
2017-04-13 19:45:18 +03:00
-- | Number the terms of an SES edit script and featurize the ones SES left unmatched.
--
--   The state holds the next index on the left-hand and right-hand side respectively. The result is a 4-tuple of:
--
--   1. the featurized left-only terms;
--   2. the featurized right-only terms;
--   3. the pairs SES already matched, tagged with both of their indices;
--   4. one entry per input diff, in order: 'None' for a left-only term, 'Term' for a right-only term, and 'Index' (of the left-hand term) for a matched pair.
genFeaturizedTermsAndDiffs :: (Functor f, HasField fields (Maybe FeatureVector))
                           => RWSEditScript f fields
                           -> State
                                (Int, Int)
                                ([UnmappedTerm f fields], [UnmappedTerm f fields], [MappedDiff f fields], [TermOrIndexOrNone (UnmappedTerm f fields)])
genFeaturizedTermsAndDiffs sesDiffs = case sesDiffs of
  [] -> pure ([], [], [], [])
  (diff : diffs) -> do
    (counterA, counterB) <- get
    case diff of
      This term -> do
        put (succ counterA, counterB)
        (as, bs, mappedDiffs, allDiffs) <- genFeaturizedTermsAndDiffs diffs
        pure (featurize counterA term : as, bs, mappedDiffs, None : allDiffs)
      That term -> do
        put (counterA, succ counterB)
        (as, bs, mappedDiffs, allDiffs) <- genFeaturizedTermsAndDiffs diffs
        -- Featurize once and share the result between both output lists.
        let featurized = featurize counterB term
        pure (as, featurized : bs, mappedDiffs, Term featurized : allDiffs)
      These a b -> do
        put (succ counterA, succ counterB)
        (as, bs, mappedDiffs, allDiffs) <- genFeaturizedTermsAndDiffs diffs
        pure (as, bs, (These counterA counterB, These a b) : mappedDiffs, Index counterA : allDiffs)
2017-04-07 22:42:32 +03:00
-- | Package a term as an 'UnmappedTerm' at the given index, moving the feature vector out of its root annotation and into the 'feature' field.
--
--   The root annotation must hold a 'Just' feature vector: the pattern below is partial, and forcing the 'feature' field of the result fails if the annotation holds 'Nothing'.
featurize :: (HasField fields (Maybe FeatureVector), Functor f) => Int -> Term f (Record fields) -> UnmappedTerm f fields
featurize index term = UnmappedTerm index vector (eraseFeatureVector term)
  where Just vector = getField (extract term)
-- | Replace the feature vector in a term’s root annotation with 'Nothing', leaving the rest of the term untouched.
eraseFeatureVector :: (Functor f, HasField fields (Maybe FeatureVector)) => Term f (Record fields) -> Term f (Record fields)
eraseFeatureVector term = cofree (setFeatureVector annotation Nothing :< syntax)
  where annotation :< syntax = runCofree term
-- | Set the feature vector field of a record ('setField' specialized to 'Maybe' 'FeatureVector').
setFeatureVector :: HasField fields (Maybe FeatureVector) => Record fields -> Maybe FeatureVector -> Record fields
setFeatureVector record vector = setField record vector
2017-04-12 19:14:36 +03:00
-- | One less than the smallest 'termIndex' among the given terms, or @-1@ (one less than 0) when the list is empty.
minimumTermIndex :: [RWS.UnmappedTerm f fields] -> Int
minimumTermIndex terms = pred (maybe 0 getMin smallest)
  where smallest = getOption (foldMap (Option . Just . Min . termIndex) terms)
2017-04-12 19:14:36 +03:00
-- | Collect unmapped terms into an 'IntMap' keyed by their 'termIndex'.
toMap :: [UnmappedTerm f fields] -> IntMap (UnmappedTerm f fields)
toMap terms = IntMap.fromList [ (termIndex each, each) | each <- terms ]
2017-04-12 19:14:36 +03:00
-- | Build a k-d tree over unmapped terms, using the elements of each term’s feature vector as its coordinates.
toKdTree :: [UnmappedTerm f fields] -> KdTree Double (UnmappedTerm f fields)
toKdTree terms = build (\ unmapped -> elems (feature unmapped)) terms
2017-04-14 21:43:48 +03:00
2017-04-24 22:53:11 +03:00
-- | A `Gram` is a fixed-size view of some portion of a tree, consisting of a `stem` of _p_ labels for parent nodes, and a `base` of _q_ labels of sibling nodes. Collectively, the bag of `Gram`s for each node of a tree (e.g. as computed by `pqGrams`) form a summary of the tree.
data Gram label = Gram
  { stem :: [Maybe label]
  , base :: [Maybe label]
  }
  deriving (Eq, Show)
-- | Annotates a term with a feature vector at each node, using the default values for the p, q, and d parameters ('defaultP', 'defaultQ', and 'defaultD').
defaultFeatureVectorDecorator
  :: (Hashable label, Traversable f)
  => Label f fields label
  -> Term f (Record fields)
  -> Term f (Record (Maybe FeatureVector ': fields))
defaultFeatureVectorDecorator getLabel term = featureVectorDecorator getLabel defaultP defaultQ defaultD term
-- | Annotates a term with a feature vector at each node, parameterized by stem length, base width, and feature vector dimensions.
--
--   Each node is first decorated with its p,q-gram. The gram is then replaced by the sum of a unit vector derived from the gram’s hash and the feature vectors already stored on the node’s immediate subterms. If any subterm’s vector is 'Nothing', so is the node’s.
featureVectorDecorator :: (Hashable label, Traversable f) => Label f fields label -> Int -> Int -> Int -> Term f (Record fields) -> Term f (Record (Maybe FeatureVector ': fields))
featureVectorDecorator getLabel p q d term = cata collect (pqGramDecorator getLabel p q term)
  where
    -- Swap the gram at the head of the annotation for the summed feature vector.
    collect ((gram :. rest) :< functor) = cofree ((subtreeVector :. rest) :< functor)
      where subtreeVector = foldl' addSubtermVector (Just (unitVector d (hash gram))) functor
    addSubtermVector :: Functor f => Maybe FeatureVector -> Term f (Record (Maybe FeatureVector ': fields)) -> Maybe FeatureVector
    addSubtermVector v subterm = addVectors <$> v <*> rhead (extract subterm)
    -- Element-wise sum over the indices 0 .. d - 1.
    addVectors :: Num a => Array Int a -> Array Int a -> Array Int a
    addVectors as bs = listArray (0, d - 1) [ as ! i + bs ! i | i <- [0 .. d - 1] ]
-- | Annotates a term with the corresponding p,q-gram at each node.
--
--   Every node starts out with a gram whose stem is all 'Nothing' and whose base holds only its own label (padded to length q). When its parent is processed, the parent’s label is pushed onto the front of the node’s stem (truncated to p), and the node’s base is replaced by a window of q labels drawn from its siblings’ bases, padded with 'Nothing'.
pqGramDecorator
  :: Traversable f
  => Label f fields label -- ^ A function computing the label from an arbitrary unpacked term. This function can use the annotation and functor’s constructor, but not any recursive values inside the functor (since they’re held parametric in 'b').
  -> Int -- ^ 'p'; the desired stem length for the grams.
  -> Int -- ^ 'q'; the desired base length for the grams.
  -> Term f (Record fields) -- ^ The term to decorate.
  -> Term f (Record (Gram label ': fields)) -- ^ The decorated term.
pqGramDecorator getLabel p q = cata algebra
  where
    algebra term = cofree ((seedGram label :. headF term) :< assignParentAndSiblingLabels (tailF term) label)
      where label = getLabel term
    -- The gram of a node before its parent and siblings are known.
    seedGram label = Gram (padToSize p []) (padToSize q [Just label])
    -- Thread the sibling labels through the children, left to right, dropping one label per child.
    assignParentAndSiblingLabels functor label = evalState (traverse (assignLabels label) functor) initialLabels
      where initialLabels = replicate (q `div` 2) Nothing <> siblingLabels functor
    assignLabels :: Functor f
                 => label
                 -> Term f (Record (Gram label ': fields))
                 -> State [Maybe label] (Term f (Record (Gram label ': fields)))
    assignLabels label a = case runCofree a of
      (childGram :. rest) :< functor -> do
        labels <- get
        put (drop 1 labels)
        pure $! cofree ((childGram { stem = padToSize p (Just label : stem childGram), base = padToSize q labels } :. rest) :< functor)
    siblingLabels :: Traversable f => f (Term f (Record (Gram label ': fields))) -> [Maybe label]
    siblingLabels children = foldMap (\ child -> base (rhead (extract child))) children
    -- Truncate or right-pad a list with 'Prologue.empty' to exactly n elements.
    padToSize n list = take n (list <> repeat Prologue.empty)
-- | Computes a unit vector of the specified dimension from a hash.
--
--   The hash seeds a pure Mersenne Twister generator, from which @d@ 'Double' components are drawn; the components are then scaled by the inverse of the vector’s Euclidean magnitude.
unitVector :: Int -> Int -> FeatureVector
unitVector d seed = fmap (* invMagnitude) components
  where
    generator = pureMT (fromIntegral seed)
    draws = sequenceA (replicate d (liftRand randomDouble))
    components = listArray (0, d - 1) (evalRand draws generator)
    sumOfSquares = sum (fmap (** 2) components)
    invMagnitude = 1 / sqrtDouble sumOfSquares
2017-06-01 18:50:21 +03:00
-- | Test the comparability of two root 'Term's in O(1), by applying the relation to their unpacked outermost layers.
canCompareTerms :: ComparabilityRelation f fields -> Term f (Record fields) -> Term f (Record fields) -> Bool
canCompareTerms canCompare a b = canCompare (runCofree a) (runCofree b)
2017-06-01 18:50:21 +03:00
-- | Recursively test the equality of two 'Term's in O(n): the roots must be comparable, and the subterms must be pairwise equal under 'liftEq'.
equalTerms :: Eq1 f => ComparabilityRelation f fields -> Term f (Record fields) -> Term f (Record fields) -> Bool
equalTerms canCompare a b = canCompareTerms canCompare a b && liftEq (equalTerms canCompare) (subtermsOf a) (subtermsOf b)
  where subtermsOf = tailF . runCofree
2017-04-24 22:53:11 +03:00
-- | Strips the head annotation off a term annotated with non-empty records, by applying 'rtail' to every annotation.
stripTerm :: Functor f => Term f (Record (h ': t)) -> Term f (Record t)
stripTerm term = rtail <$> term
-- | Strips the head annotation off a diff annotated with non-empty records, by mapping 'rtail' over its annotations.
stripDiff
  :: (Functor f, Functor g)
  => Free (TermF f (g (Record (h ': t)))) (Patch (Term f (Record (h ': t))))
  -> Free (TermF f (g (Record t))) (Patch (Term f (Record t)))
stripDiff diff = mapAnnotations rtail diff
-- Instances
-- | A 'Gram' hashes as the concatenation of its stem and base labels.
--
--   'hashWithSalt' threads the supplied salt through to the label list rather than discarding it, so different salts can yield different hashes. 'hash' is unchanged.
instance Hashable label => Hashable (Gram label) where
  hashWithSalt salt gram = hashWithSalt salt (stem gram <> base gram)
  hash gram = hash (stem gram <> base gram)
-- | Enumerate 'Gram's by tiers: both the stem and the base are drawn from the tiers of lists of optional labels.
instance Listable1 Gram where
  liftTiers labelTiers = liftCons2 optionalLabelLists optionalLabelLists Gram
    where optionalLabelLists = liftTiers (liftTiers labelTiers)
-- | 'Gram's over a 'Listable' label type are enumerated via 'tiers1'.
instance Listable a => Listable (Gram a) where
  tiers = tiers1