2017-04-08 00:48:14 +03:00
{- # LANGUAGE GADTs, DataKinds, TypeOperators # -}
2017-04-08 00:59:45 +03:00
module RWS ( rws ) where
2017-04-07 21:44:37 +03:00
import Prologue
2017-04-11 23:29:37 +03:00
import Control.Monad.Effect as Eff
2017-04-07 21:44:37 +03:00
import Control.Monad.Effect.Internal as I
import Data.Record
import Data.These
import Term
import Data.Array
import Data.Functor.Classes
import Info
import SES
2017-04-08 00:48:14 +03:00
import qualified Data.Functor.Both as Both
2017-04-07 21:44:37 +03:00
import Data.Functor.Classes.Eq.Generic
2017-04-08 00:59:45 +03:00
import Data.RandomWalkSimilarity ( FeatureVector )
2017-04-07 21:44:37 +03:00
2017-04-08 00:48:14 +03:00
import Data.KdTree.Static hiding ( toList )
2017-04-07 23:08:49 +03:00
import qualified Data.IntMap as IntMap
import Data.Semigroup ( Min ( .. ) , Option ( .. ) )
2017-04-07 21:44:37 +03:00
-- rws :: (GAlign f, Traversable f, Eq1 f, HasField fields Category, HasField fields (Maybe FeatureVector))
-- => (These (Term f (Record fields)) (Term f (Record fields)) -> Int) -- ^ A function computes a constant-time approximation to the edit distance between two terms.
-- -> (Term f (Record fields) -> Term f (Record fields) -> Bool) -- ^ A relation determining whether two terms can be compared.
-- -> [Term f (Record fields)] -- ^ The list of old terms.
-- -> [Term f (Record fields)] -- ^ The list of new terms.
-- -> [These (Term f (Record fields)) (Term f (Record fields))] -- ^ The resulting list of similarity-matched diffs.
-- rws editDistance canCompare as bs = undefined
2017-04-07 22:42:32 +03:00
-- | A term which has not yet been mapped by `rws`, along with its feature vector summary & index.
data UnmappedTerm f fields = UnmappedTerm {
termIndex :: Int -- ^ The index of the term within its root term.
, feature :: FeatureVector -- ^ Feature vector
, term :: Term f ( Record fields ) -- ^ The unmapped term
}
-- | Either a `term`, an index of a matched term, or nil.
data TermOrIndexOrNone term = Term term | Index Int | None
2017-04-08 00:59:45 +03:00
rws :: ( HasField fields Category , HasField fields ( Maybe FeatureVector ) , Foldable t , Functor f , Eq1 f ) => ( These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) -> Int ) -> ( Term f ( Record fields ) -> Term f ( Record fields ) -> Bool ) -> t ( Term f ( Record fields ) ) -> t ( Term f ( Record fields ) ) -> RWSEditScript f fields
2017-04-11 23:29:37 +03:00
rws editDistance canCompare as bs = Eff . run $ RWS . run editDistance canCompare as bs rws'
2017-04-08 00:59:45 +03:00
rws' :: ( HasField fields ( Maybe FeatureVector ) , RWS f fields :< e ) => Eff e [ These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ]
rws' = do
2017-04-07 23:08:49 +03:00
sesDiffs <- ses'
( featureAs , featureBs , mappedDiffs , allDiffs ) <- genFeaturizedTermsAndDiffs' sesDiffs
2017-04-08 00:48:14 +03:00
( diffs , remaining ) <- findNearestNeighoursToDiff' allDiffs featureAs featureBs
diffs' <- deleteRemaining' diffs remaining
2017-04-12 00:10:08 +03:00
rwsDiffs <- insertMapped' mappedDiffs diffs'
2017-04-08 00:48:14 +03:00
pure ( fmap snd rwsDiffs )
2017-04-07 22:59:00 +03:00
2017-04-08 00:48:14 +03:00
ses' :: ( HasField fields ( Maybe FeatureVector ) , RWS f fields :< e ) => Eff e ( RWSEditScript f fields )
2017-04-07 23:08:49 +03:00
ses' = send SES
2017-04-08 00:48:14 +03:00
genFeaturizedTermsAndDiffs' :: ( HasField fields ( Maybe FeatureVector ) , RWS f fields :< e )
=> RWSEditScript f fields
-> Eff e ( [ UnmappedTerm f fields ] , [ UnmappedTerm f fields ] , [ ( These Int Int , Diff f fields ) ] , [ TermOrIndexOrNone ( UnmappedTerm f fields ) ] )
2017-04-07 23:08:49 +03:00
genFeaturizedTermsAndDiffs' = send . GenFeaturizedTermsAndDiffs
2017-04-08 00:48:14 +03:00
findNearestNeighoursToDiff' :: ( RWS f fields :< e )
=> [ TermOrIndexOrNone ( UnmappedTerm f fields ) ]
-> [ UnmappedTerm f fields ]
-> [ UnmappedTerm f fields ]
-> Eff e ( [ ( These Int Int , Diff f fields ) ] , UnmappedTerms f fields )
findNearestNeighoursToDiff' diffs as bs = send ( FindNearestNeighoursToDiff diffs as bs )
2017-04-07 21:44:37 +03:00
2017-04-12 19:14:36 +03:00
deleteRemaining' :: ( RWS f fields :< e )
=> [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ]
-> UnmappedTerms f fields
-> Eff e [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ]
2017-04-08 00:48:14 +03:00
deleteRemaining' diffs remaining = send ( DeleteRemaining diffs remaining )
2017-04-12 19:14:36 +03:00
insertMapped' :: ( RWS f fields :< e ) => [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ] -> [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ] -> Eff e [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ]
2017-04-08 00:48:14 +03:00
insertMapped' diffs mappedDiffs = send ( InsertMapped diffs mappedDiffs )
data RWS f fields result where
2017-04-07 21:44:37 +03:00
-- RWS :: RWS a b (EditScript a b)
SES :: RWS f fields ( RWSEditScript f fields )
2017-04-07 22:42:32 +03:00
-- FindNearestNeighbourToDiff :: TermOrIndexOrNone (UnmappedTerm f fields) ->
GenFeaturizedTermsAndDiffs :: HasField fields ( Maybe FeatureVector ) => RWSEditScript f fields -> RWS f fields ( [ UnmappedTerm f fields ] , [ UnmappedTerm f fields ] , [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ] , [ TermOrIndexOrNone ( UnmappedTerm f fields ) ] )
2017-04-08 00:48:14 +03:00
FindNearestNeighoursToDiff :: [ TermOrIndexOrNone ( UnmappedTerm f fields ) ] -> [ UnmappedTerm f fields ] -> [ UnmappedTerm f fields ] -> RWS f fields ( [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ] , UnmappedTerms f fields )
DeleteRemaining :: [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ] -> UnmappedTerms f fields -> RWS f fields [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ]
InsertMapped :: [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ] -> [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ] -> RWS f fields [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ]
2017-04-07 21:44:37 +03:00
-- EraseFeatureVector :: forall a b f fields. RwsF a b (EditScript (Term f (Record fields)) (Term f (Record fields)))
2017-04-08 00:48:14 +03:00
-- | An IntMap of unmapped terms keyed by their position in a list of terms.
type UnmappedTerms f fields = IntMap ( UnmappedTerm f fields )
2017-04-07 21:44:37 +03:00
type RWSEditScript f fields = [ These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ]
2017-04-08 00:59:45 +03:00
run :: ( Eq1 f , Functor f , HasField fields Category , HasField fields ( Maybe FeatureVector ) , Foldable t )
2017-04-08 00:48:14 +03:00
=> ( These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) -> Int ) -- ^ A function computes a constant-time approximation to the edit distance between two terms.
-> ( Term f ( Record fields ) -> Term f ( Record fields ) -> Bool ) -- ^ A relation determining whether two terms can be compared.
-> t ( Term f ( Record fields ) )
-> t ( Term f ( Record fields ) )
2017-04-11 23:29:37 +03:00
-> Eff ( RWS f fields ': e ) ( RWSEditScript f fields )
-> Eff e ( RWSEditScript f fields )
2017-04-12 01:23:32 +03:00
run editDistance canCompare as bs = relay pure ( \ m q -> q $ case m of
SES -> ses ( gliftEq ( == ) ` on ` fmap category ) as bs
2017-04-11 23:29:37 +03:00
( GenFeaturizedTermsAndDiffs sesDiffs ) ->
2017-04-12 19:14:36 +03:00
evalState ( genFeaturizedTermsAndDiffs sesDiffs ) ( 0 , 0 )
2017-04-11 23:29:37 +03:00
( FindNearestNeighoursToDiff allDiffs featureAs featureBs ) ->
2017-04-12 01:23:32 +03:00
findNearestNeighboursToDiff editDistance canCompare allDiffs featureAs featureBs
2017-04-11 23:29:37 +03:00
( DeleteRemaining allDiffs remainingDiffs ) ->
2017-04-12 01:23:32 +03:00
deleteRemaining allDiffs remainingDiffs
2017-04-11 23:29:37 +03:00
( InsertMapped allDiffs mappedDiffs ) ->
2017-04-12 01:23:32 +03:00
insertMapped allDiffs mappedDiffs )
2017-04-08 00:48:14 +03:00
type Diff f fields = These ( Term f ( Record fields ) ) ( Term f ( Record fields ) )
2017-04-12 00:10:08 +03:00
insertMapped :: Foldable t => t ( These Int Int , Diff f fields ) -> [ ( These Int Int , Diff f fields ) ] -> [ ( These Int Int , Diff f fields ) ]
2017-04-12 21:46:27 +03:00
insertMapped diffs into = foldl' ( flip insertDiff ) into diffs
2017-04-12 00:10:08 +03:00
deleteRemaining :: ( Traversable t )
=> [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ]
-> t ( RWS . UnmappedTerm f fields )
-> [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ]
deleteRemaining diffs unmappedAs =
2017-04-12 21:46:27 +03:00
foldl' ( flip insertDiff ) diffs ( ( This . termIndex &&& This . term ) <$> unmappedAs )
2017-04-08 00:48:14 +03:00
-- | Inserts an index and diff pair into a list of indices and diffs.
2017-04-12 01:33:59 +03:00
insertDiff :: ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) )
-> [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ]
-> [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ]
2017-04-08 00:48:14 +03:00
insertDiff inserted [] = [ inserted ]
insertDiff a @ ( ij1 , _ ) ( b @ ( ij2 , _ ) : rest ) = case ( ij1 , ij2 ) of
( These i1 i2 , These j1 j2 ) -> if i1 <= j1 && i2 <= j2 then a : b : rest else b : insertDiff a rest
( This i , This j ) -> if i <= j then a : b : rest else b : insertDiff a rest
( That i , That j ) -> if i <= j then a : b : rest else b : insertDiff a rest
( This i , These j _ ) -> if i <= j then a : b : rest else b : insertDiff a rest
( That i , These _ j ) -> if i <= j then a : b : rest else b : insertDiff a rest
( This _ , That _ ) -> b : insertDiff a rest
( That _ , This _ ) -> b : insertDiff a rest
( These i1 i2 , _ ) -> case break ( isThese . fst ) rest of
( rest , tail ) -> let ( before , after ) = foldr' ( combine i1 i2 ) ( [] , [] ) ( b : rest ) in
case after of
[] -> before <> insertDiff a tail
_ -> before <> ( a : after ) <> tail
where
combine i1 i2 each ( before , after ) = case fst each of
This j1 -> if i1 <= j1 then ( before , each : after ) else ( each : before , after )
That j2 -> if i2 <= j2 then ( before , each : after ) else ( each : before , after )
These _ _ -> ( before , after )
findNearestNeighboursToDiff :: ( These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) -> Int ) -- ^ A function computes a constant-time approximation to the edit distance between two terms.
-> ( Term f ( Record fields ) -> Term f ( Record fields ) -> Bool ) -- ^ A relation determining whether two terms can be compared.
-> [ TermOrIndexOrNone ( UnmappedTerm f fields ) ]
-> [ UnmappedTerm f fields ]
-> [ UnmappedTerm f fields ]
-> ( [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ] , UnmappedTerms f fields )
2017-04-12 21:46:27 +03:00
findNearestNeighboursToDiff editDistance canCompare allDiffs featureAs featureBs = ( diffs , remaining )
where
( diffs , ( _ , remaining , _ ) ) =
traverse ( findNearestNeighbourToDiff' editDistance canCompare ( toKdTree <$> Both . both featureAs featureBs ) ) allDiffs &
fmap catMaybes &
( ` runState ` ( minimumTermIndex featureAs , toMap featureAs , toMap featureBs ) )
2017-04-08 00:48:14 +03:00
findNearestNeighbourToDiff' :: ( These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) -> Int ) -- ^ A function computes a constant-time approximation to the edit distance between two terms.
-> ( Term f ( Record fields ) -> Term f ( Record fields ) -> Bool ) -- ^ A relation determining whether two terms can be compared.
-> Both . Both ( KdTree Double ( UnmappedTerm f fields ) )
-> TermOrIndexOrNone ( UnmappedTerm f fields )
-> State ( Int , UnmappedTerms f fields , UnmappedTerms f fields )
( Maybe ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) )
findNearestNeighbourToDiff' editDistance canCompare kdTrees termThing = case termThing of
None -> pure Nothing
Term term -> Just <$> findNearestNeighbourTo editDistance canCompare kdTrees term
Index i -> do
( _ , unA , unB ) <- get
put ( i , unA , unB )
pure Nothing
-- | Construct a diff for a term in B by matching it against the most similar eligible term in A (if any), marking both as ineligible for future matches.
findNearestNeighbourTo :: ( These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) -> Int ) -- ^ A function computes a constant-time approximation to the edit distance between two terms.
-> ( Term f ( Record fields ) -> Term f ( Record fields ) -> Bool ) -- ^ A relation determining whether two terms can be compared.
-> Both . Both ( KdTree Double ( UnmappedTerm f fields ) )
-> UnmappedTerm f fields
-> State ( Int , UnmappedTerms f fields , UnmappedTerms f fields )
( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) )
findNearestNeighbourTo editDistance canCompare kdTrees term @ ( UnmappedTerm j _ b ) = do
( previous , unmappedA , unmappedB ) <- get
fromMaybe ( insertion previous unmappedA unmappedB term ) $ do
-- Look up the nearest unmapped term in `unmappedA`.
foundA @ ( UnmappedTerm i _ a ) <- nearestUnmapped editDistance canCompare ( IntMap . filterWithKey ( \ k _ ->
isInMoveBounds previous k )
unmappedA ) ( Both . fst kdTrees ) term
-- Look up the nearest `foundA` in `unmappedB`
UnmappedTerm j' _ _ <- nearestUnmapped editDistance canCompare unmappedB ( Both . snd kdTrees ) foundA
-- Return Nothing if their indices don't match
guard ( j == j' )
guard ( canCompare a b )
pure $! do
put ( i , IntMap . delete i unmappedA , IntMap . delete j unmappedB )
pure ( These i j , These a b )
2017-04-12 19:14:36 +03:00
isInMoveBounds :: Int -> Int -> Bool
2017-04-08 00:48:14 +03:00
isInMoveBounds previous i = previous < i && i < previous + defaultMoveBound
-- | Finds the most-similar unmapped term to the passed-in term, if any.
--
-- RWS can produce false positives in the case of e.g. hash collisions. Therefore, we find the _l_ nearest candidates, filter out any which have already been mapped, and select the minimum of the remaining by (a constant-time approximation of) edit distance.
--
-- cf §4.2 of RWS-Diff
nearestUnmapped
:: ( These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) -> Int ) -- ^ A function computes a constant-time approximation to the edit distance between two terms.
-> ( Term f ( Record fields ) -> Term f ( Record fields ) -> Bool ) -- ^ A relation determining whether two terms can be compared.
-> UnmappedTerms f fields -- ^ A set of terms eligible for matching against.
-> KdTree Double ( UnmappedTerm f fields ) -- ^ The k-d tree to look up nearest neighbours within.
-> UnmappedTerm f fields -- ^ The term to find the nearest neighbour to.
-> Maybe ( UnmappedTerm f fields ) -- ^ The most similar unmapped term, if any.
nearestUnmapped editDistance canCompare unmapped tree key = getFirst $ foldMap ( First . Just ) ( sortOn ( editDistanceIfComparable editDistance canCompare ( term key ) . term ) ( toList ( IntMap . intersection unmapped ( toMap ( kNearest tree defaultL key ) ) ) ) )
2017-04-12 19:14:36 +03:00
editDistanceIfComparable :: Bounded t => ( These a b -> t ) -> ( a -> b -> Bool ) -> a -> b -> t
2017-04-08 00:48:14 +03:00
editDistanceIfComparable editDistance canCompare a b = if canCompare a b
then editDistance ( These a b )
else maxBound
2017-04-12 19:14:36 +03:00
defaultL , defaultMoveBound :: Int
2017-04-08 00:48:14 +03:00
defaultL = 2
defaultMoveBound = 2
-- Returns a state (insertion index, old unmapped terms, new unmapped terms), and value of (index, inserted diff),
-- given a previous index, two sets of umapped terms, and an unmapped term to insert.
insertion :: Int
-> UnmappedTerms f fields
-> UnmappedTerms f fields
-> UnmappedTerm f fields
-> State ( Int , UnmappedTerms f fields , UnmappedTerms f fields )
( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) )
insertion previous unmappedA unmappedB ( UnmappedTerm j _ b ) = do
put ( previous , unmappedA , IntMap . delete j unmappedB )
pure ( That j , That b )
2017-04-07 22:42:32 +03:00
2017-04-12 19:14:36 +03:00
-- genFeaturizedTermsAndDiffs :: (Functor f, HasField fields (Maybe FeatureVector)) => RWSEditScript f fields -> ([UnmappedTerm f fields], [UnmappedTerm f fields], [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))], [TermOrIndexOrNone (UnmappedTerm f fields)])
-- genFeaturizedTermsAndDiffs sesDiffs = (featurizedAs, featurizedBs, countersAndDiffs, allDiffs)
-- where
-- (featurizedAs, featurizedBs, _, _, countersAndDiffs, allDiffs) = foldl' (\(as, bs, counterA, counterB, diffs, allDiffs) diff ->
-- case diff of
-- This term ->
-- (as <> pure (featurize counterA term), bs, succ counterA, counterB, diffs, allDiffs <> pure None)
-- That term ->
-- (as, bs <> pure (featurize counterB term), counterA, succ counterB, diffs, allDiffs <> pure (Term (featurize counterB term)))
-- These a b ->
-- (as, bs, succ counterA, succ counterB, diffs <> pure (These counterA counterB, These a b), allDiffs <> pure (Index counterA))
-- ) ([], [], 0, 0, [], []) sesDiffs
genFeaturizedTermsAndDiffs :: ( Functor f , HasField fields ( Maybe FeatureVector ) ) => RWSEditScript f fields -> State ( Int , Int ) ( [ UnmappedTerm f fields ] , [ UnmappedTerm f fields ] , [ ( These Int Int , These ( Term f ( Record fields ) ) ( Term f ( Record fields ) ) ) ] , [ TermOrIndexOrNone ( UnmappedTerm f fields ) ] )
genFeaturizedTermsAndDiffs sesDiffs = go
2017-04-12 00:10:25 +03:00
where
2017-04-12 19:14:36 +03:00
go = case sesDiffs of
[] -> pure ( [] , [] , [] , [] )
( diff : diffs ) -> do
( counterA , counterB ) <- get
case diff of
This term -> do
put ( succ counterA , counterB )
( as , bs , mappedDiffs , allDiffs ) <- genFeaturizedTermsAndDiffs diffs
pure ( featurize counterA term : as , bs , mappedDiffs , None : allDiffs )
That term -> do
put ( counterA , succ counterB )
( as , bs , mappedDiffs , allDiffs ) <- genFeaturizedTermsAndDiffs diffs
pure ( as , featurize counterB term : bs , mappedDiffs , Term ( featurize counterB term ) : allDiffs )
These a b -> do
put ( succ counterA , succ counterB )
( as , bs , mappedDiffs , allDiffs ) <- genFeaturizedTermsAndDiffs diffs
pure ( as , bs , ( These counterA counterB , These a b ) : mappedDiffs , Index counterA : allDiffs )
2017-04-07 22:42:32 +03:00
featurize :: ( HasField fields ( Maybe FeatureVector ) , Functor f ) => Int -> Term f ( Record fields ) -> UnmappedTerm f fields
featurize index term = UnmappedTerm index ( let Just v = getField ( extract term ) in v ) ( eraseFeatureVector term )
eraseFeatureVector :: ( Functor f , HasField fields ( Maybe FeatureVector ) ) => Term f ( Record fields ) -> Term f ( Record fields )
eraseFeatureVector term = let record :< functor = runCofree term in
cofree ( setFeatureVector record Nothing :< functor )
setFeatureVector :: HasField fields ( Maybe FeatureVector ) => Record fields -> Maybe FeatureVector -> Record fields
setFeatureVector = setField
2017-04-12 19:14:36 +03:00
minimumTermIndex :: [ RWS . UnmappedTerm f fields ] -> Int
2017-04-07 23:08:49 +03:00
minimumTermIndex = pred . maybe 0 getMin . getOption . foldMap ( Option . Just . Min . termIndex )
2017-04-12 19:14:36 +03:00
toMap :: [ UnmappedTerm f fields ] -> IntMap ( UnmappedTerm f fields )
2017-04-07 23:08:49 +03:00
toMap = IntMap . fromList . fmap ( termIndex &&& identity )
2017-04-12 19:14:36 +03:00
toKdTree :: [ UnmappedTerm f fields ] -> KdTree Double ( UnmappedTerm f fields )
2017-04-08 00:48:14 +03:00
toKdTree = build ( elems . feature )