-- NOTE(review): this is a unified git diff of src/RWS.hs (a62b66cee..5bf87e221), not
-- Haskell source. Its line breaks and indentation appear to have been collapsed, so
-- many hunk lines share one physical line; it presumably will not apply with
-- `git apply` / `patch` until the original line structure is restored.
--
-- What the patch changes, per the hunks below:
--   * Adds `type MappedDiff f fields = (These Int Int, Diff f fields)`, moves the
--     existing `Diff f fields` alias up next to `UnmappedTerms`, and redefines
--     `RWSEditScript f fields` as `[Diff f fields]`.
--   * Rewrites signatures to use those aliases instead of the spelled-out `These`
--     types: rws, run, genFeaturizedTermsAndDiffs', findNearestNeighoursToDiff',
--     deleteRemaining', insertMapped', the RWS GADT constructors, insertMapped,
--     deleteRemaining, insertDiff, findNearestNeighbourToDiff',
--     findNearestNeighbourTo, nearestUnmapped, insertion, genFeaturizedTermsAndDiffs.
--   * Splits the long one-line signatures of insertMapped' and of most RWS
--     constructors across several lines.
--   * Drops the `RWS.` qualifier on `UnmappedTerm` in deleteRemaining's signature.
--   * Deletes commented-out code: the old `RWS`, `FindNearestNeighbourToDiff` and
--     `EraseFeatureVector` constructor stubs, and the old fold-based
--     genFeaturizedTermsAndDiffs.
--   * Removes the `where go =` indirection in genFeaturizedTermsAndDiffs; the
--     `case sesDiffs of` body and its three branches are otherwise the same.
--   * Removes two blank lines before findNearestNeighboursToDiff.
--
-- Visible only as context lines (not modified by this patch); possible follow-ups:
--   * findNearestNeighboursToDiff still spells its first argument as
--     `These (Term f (Record fields)) (Term f (Record fields)) -> Int` rather than
--     `Diff f fields -> Int`.
--   * `findNearestNeighoursToDiff'` / `FindNearestNeighoursToDiff` are spelled
--     "Neighours", unlike `findNearestNeighboursToDiff`.
--   * featurize binds `let Just v = getField (extract term) in v`, a partial
--     pattern; presumably this fails at runtime when the field is `Nothing` --
--     TODO confirm against callers.
diff --git a/src/RWS.hs b/src/RWS.hs index a62b66cee..5bf87e221 100644 --- a/src/RWS.hs +++ b/src/RWS.hs @@ -30,7 +30,7 @@ data UnmappedTerm f fields = UnmappedTerm { data TermOrIndexOrNone term = Term term | Index Int | None rws :: (HasField fields Category, HasField fields (Maybe FeatureVector), Foldable t, Functor f, Eq1 f) - => (These (Term f (Record fields)) (Term f (Record fields)) -> Int) + => (Diff f fields -> Int) -> (Term f (Record fields) -> Term f (Record fields) -> Bool) -> t (Term f (Record fields)) -> t (Term f (Record fields)) @@ -48,46 +48,59 @@ ses' = send SES genFeaturizedTermsAndDiffs' :: (HasField fields (Maybe FeatureVector), RWS f fields :< e) => RWSEditScript f fields - -> Eff e ([UnmappedTerm f fields], [UnmappedTerm f fields], [(These Int Int, Diff f fields)], [TermOrIndexOrNone (UnmappedTerm f fields)]) + -> Eff e ([UnmappedTerm f fields], [UnmappedTerm f fields], [MappedDiff f fields], [TermOrIndexOrNone (UnmappedTerm f fields)]) genFeaturizedTermsAndDiffs' = send . 
GenFeaturizedTermsAndDiffs findNearestNeighoursToDiff' :: (RWS f fields :< e) => [TermOrIndexOrNone (UnmappedTerm f fields)] -> [UnmappedTerm f fields] -> [UnmappedTerm f fields] - -> Eff e ([(These Int Int, Diff f fields)], UnmappedTerms f fields) + -> Eff e ([MappedDiff f fields], UnmappedTerms f fields) findNearestNeighoursToDiff' diffs as bs = send (FindNearestNeighoursToDiff diffs as bs) deleteRemaining' :: (RWS f fields :< e) - => [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))] + => [MappedDiff f fields] -> UnmappedTerms f fields - -> Eff e [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))] + -> Eff e [MappedDiff f fields] deleteRemaining' diffs remaining = send (DeleteRemaining diffs remaining) -insertMapped' :: (RWS f fields :< e) => [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))] -> [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))] -> Eff e [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))] +insertMapped' :: (RWS f fields :< e) + => [MappedDiff f fields] + -> [MappedDiff f fields] + -> Eff e [MappedDiff f fields] insertMapped' diffs mappedDiffs = send (InsertMapped diffs mappedDiffs) data RWS f fields result where - -- RWS :: RWS a b (EditScript a b) + SES :: RWS f fields (RWSEditScript f fields) - -- FindNearestNeighbourToDiff :: TermOrIndexOrNone (UnmappedTerm f fields) -> - GenFeaturizedTermsAndDiffs :: HasField fields (Maybe FeatureVector) => RWSEditScript f fields -> RWS f fields ([UnmappedTerm f fields], [UnmappedTerm f fields], [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))], [TermOrIndexOrNone (UnmappedTerm f fields)]) - FindNearestNeighoursToDiff :: [TermOrIndexOrNone (UnmappedTerm f fields)] -> [UnmappedTerm f fields] -> [UnmappedTerm f fields] -> RWS f fields ([(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))], UnmappedTerms f fields) + GenFeaturizedTermsAndDiffs :: 
HasField fields (Maybe FeatureVector) + => RWSEditScript f fields + -> RWS f fields ([UnmappedTerm f fields], [UnmappedTerm f fields], [MappedDiff f fields], [TermOrIndexOrNone (UnmappedTerm f fields)]) - DeleteRemaining :: [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))] -> UnmappedTerms f fields -> RWS f fields [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))] + FindNearestNeighoursToDiff :: [TermOrIndexOrNone (UnmappedTerm f fields)] + -> [UnmappedTerm f fields] + -> [UnmappedTerm f fields] + -> RWS f fields ([MappedDiff f fields], UnmappedTerms f fields) - InsertMapped :: [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))] -> [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))] -> RWS f fields [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))] - -- EraseFeatureVector :: forall a b f fields. RwsF a b (EditScript (Term f (Record fields)) (Term f (Record fields))) + DeleteRemaining :: [MappedDiff f fields] + -> UnmappedTerms f fields + -> RWS f fields [MappedDiff f fields] + + InsertMapped :: [MappedDiff f fields] -> [MappedDiff f fields] -> RWS f fields [MappedDiff f fields] -- | An IntMap of unmapped terms keyed by their position in a list of terms. type UnmappedTerms f fields = IntMap (UnmappedTerm f fields) -type RWSEditScript f fields = [These (Term f (Record fields)) (Term f (Record fields))] +type Diff f fields = These (Term f (Record fields)) (Term f (Record fields)) + +type MappedDiff f fields = (These Int Int, Diff f fields) + +type RWSEditScript f fields = [Diff f fields] run :: (Eq1 f, Functor f, HasField fields Category, HasField fields (Maybe FeatureVector), Foldable t) - => (These (Term f (Record fields)) (Term f (Record fields)) -> Int) -- ^ A function computes a constant-time approximation to the edit distance between two terms. 
+ => (Diff f fields -> Int) -- ^ A function computes a constant-time approximation to the edit distance between two terms. -> (Term f (Record fields) -> Term f (Record fields) -> Bool) -- ^ A relation determining whether two terms can be compared. -> t (Term f (Record fields)) -> t (Term f (Record fields)) @@ -104,22 +117,20 @@ run editDistance canCompare as bs = relay pure (\m q -> q $ case m of (InsertMapped allDiffs mappedDiffs) -> insertMapped allDiffs mappedDiffs) -type Diff f fields = These (Term f (Record fields)) (Term f (Record fields)) - -insertMapped :: Foldable t => t (These Int Int, Diff f fields) -> [(These Int Int, Diff f fields)] -> [(These Int Int, Diff f fields)] +insertMapped :: Foldable t => t (MappedDiff f fields) -> [MappedDiff f fields] -> [MappedDiff f fields] insertMapped diffs into = foldl' (flip insertDiff) into diffs deleteRemaining :: (Traversable t) - => [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))] - -> t (RWS.UnmappedTerm f fields) - -> [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))] + => [MappedDiff f fields] + -> t (UnmappedTerm f fields) + -> [MappedDiff f fields] deleteRemaining diffs unmappedAs = foldl' (flip insertDiff) diffs ((This . termIndex &&& This . term) <$> unmappedAs) -- | Inserts an index and diff pair into a list of indices and diffs. 
-insertDiff :: (These Int Int, These (Term f (Record fields)) (Term f (Record fields))) - -> [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))] - -> [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))] +insertDiff :: MappedDiff f fields + -> [MappedDiff f fields] + -> [MappedDiff f fields] insertDiff inserted [] = [ inserted ] insertDiff a@(ij1, _) (b@(ij2, _):rest) = case (ij1, ij2) of (These i1 i2, These j1 j2) -> if i1 <= j1 && i2 <= j2 then a : b : rest else b : insertDiff a rest @@ -142,8 +153,6 @@ insertDiff a@(ij1, _) (b@(ij2, _):rest) = case (ij1, ij2) of That j2 -> if i2 <= j2 then (before, each : after) else (each : before, after) These _ _ -> (before, after) - - findNearestNeighboursToDiff :: (These (Term f (Record fields)) (Term f (Record fields)) -> Int) -- ^ A function computes a constant-time approximation to the edit distance between two terms. -> (Term f (Record fields) -> Term f (Record fields) -> Bool) -- ^ A relation determining whether two terms can be compared. -> [TermOrIndexOrNone (UnmappedTerm f fields)] @@ -157,12 +166,12 @@ findNearestNeighboursToDiff editDistance canCompare allDiffs featureAs featureBs fmap catMaybes & (`runState` (minimumTermIndex featureAs, toMap featureAs, toMap featureBs)) -findNearestNeighbourToDiff' :: (These (Term f (Record fields)) (Term f (Record fields)) -> Int) -- ^ A function computes a constant-time approximation to the edit distance between two terms. +findNearestNeighbourToDiff' :: (Diff f fields -> Int) -- ^ A function computes a constant-time approximation to the edit distance between two terms. -> (Term f (Record fields) -> Term f (Record fields) -> Bool) -- ^ A relation determining whether two terms can be compared. 
-> Both.Both (KdTree Double (UnmappedTerm f fields)) -> TermOrIndexOrNone (UnmappedTerm f fields) -> State (Int, UnmappedTerms f fields, UnmappedTerms f fields) - (Maybe (These Int Int, These (Term f (Record fields)) (Term f (Record fields)))) + (Maybe (MappedDiff f fields)) findNearestNeighbourToDiff' editDistance canCompare kdTrees termThing = case termThing of None -> pure Nothing Term term -> Just <$> findNearestNeighbourTo editDistance canCompare kdTrees term @@ -172,12 +181,12 @@ findNearestNeighbourToDiff' editDistance canCompare kdTrees termThing = case ter pure Nothing -- | Construct a diff for a term in B by matching it against the most similar eligible term in A (if any), marking both as ineligible for future matches. -findNearestNeighbourTo :: (These (Term f (Record fields)) (Term f (Record fields)) -> Int) -- ^ A function computes a constant-time approximation to the edit distance between two terms. +findNearestNeighbourTo :: (Diff f fields -> Int) -- ^ A function computes a constant-time approximation to the edit distance between two terms. -> (Term f (Record fields) -> Term f (Record fields) -> Bool) -- ^ A relation determining whether two terms can be compared. -> Both.Both (KdTree Double (UnmappedTerm f fields)) -> UnmappedTerm f fields -> State (Int, UnmappedTerms f fields, UnmappedTerms f fields) - (These Int Int, These (Term f (Record fields)) (Term f (Record fields))) + (MappedDiff f fields) findNearestNeighbourTo editDistance canCompare kdTrees term@(UnmappedTerm j _ b) = do (previous, unmappedA, unmappedB) <- get fromMaybe (insertion previous unmappedA unmappedB term) $ do @@ -203,7 +212,7 @@ isInMoveBounds previous i = previous < i && i < previous + defaultMoveBound -- -- cf §4.2 of RWS-Diff nearestUnmapped - :: (These (Term f (Record fields)) (Term f (Record fields)) -> Int) -- ^ A function computes a constant-time approximation to the edit distance between two terms. 
+ :: (Diff f fields -> Int) -- ^ A function computes a constant-time approximation to the edit distance between two terms. -> (Term f (Record fields) -> Term f (Record fields) -> Bool) -- ^ A relation determining whether two terms can be compared. -> UnmappedTerms f fields -- ^ A set of terms eligible for matching against. -> KdTree Double (UnmappedTerm f fields) -- ^ The k-d tree to look up nearest neighbours within. @@ -227,43 +236,33 @@ insertion :: Int -> UnmappedTerms f fields -> UnmappedTerm f fields -> State (Int, UnmappedTerms f fields, UnmappedTerms f fields) - (These Int Int, These (Term f (Record fields)) (Term f (Record fields))) + (MappedDiff f fields) insertion previous unmappedA unmappedB (UnmappedTerm j _ b) = do put (previous, unmappedA, IntMap.delete j unmappedB) pure (That j, That b) --- genFeaturizedTermsAndDiffs :: (Functor f, HasField fields (Maybe FeatureVector)) => RWSEditScript f fields -> ([UnmappedTerm f fields], [UnmappedTerm f fields], [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))], [TermOrIndexOrNone (UnmappedTerm f fields)]) --- genFeaturizedTermsAndDiffs sesDiffs = (featurizedAs, featurizedBs, countersAndDiffs, allDiffs) --- where --- (featurizedAs, featurizedBs, _, _, countersAndDiffs, allDiffs) = foldl' (\(as, bs, counterA, counterB, diffs, allDiffs) diff -> --- case diff of --- This term -> --- (as <> pure (featurize counterA term), bs, succ counterA, counterB, diffs, allDiffs <> pure None) --- That term -> --- (as, bs <> pure (featurize counterB term), counterA, succ counterB, diffs, allDiffs <> pure (Term (featurize counterB term))) --- These a b -> --- (as, bs, succ counterA, succ counterB, diffs <> pure (These counterA counterB, These a b), allDiffs <> pure (Index counterA)) --- ) ([], [], 0, 0, [], []) sesDiffs -genFeaturizedTermsAndDiffs :: (Functor f, HasField fields (Maybe FeatureVector)) => RWSEditScript f fields -> State (Int, Int) ([UnmappedTerm f fields], [UnmappedTerm f fields], [(These Int 
Int, These (Term f (Record fields)) (Term f (Record fields)))], [TermOrIndexOrNone (UnmappedTerm f fields)]) -genFeaturizedTermsAndDiffs sesDiffs = go - where - go = case sesDiffs of - [] -> pure ([], [], [], []) - (diff : diffs) -> do - (counterA, counterB) <- get - case diff of - This term -> do - put (succ counterA, counterB) - (as, bs, mappedDiffs, allDiffs) <- genFeaturizedTermsAndDiffs diffs - pure (featurize counterA term : as, bs, mappedDiffs, None : allDiffs ) - That term -> do - put (counterA, succ counterB) - (as, bs, mappedDiffs, allDiffs) <- genFeaturizedTermsAndDiffs diffs - pure (as, featurize counterB term : bs, mappedDiffs, Term (featurize counterB term) : allDiffs) - These a b -> do - put (succ counterA, succ counterB) - (as, bs, mappedDiffs, allDiffs) <- genFeaturizedTermsAndDiffs diffs - pure (as, bs, (These counterA counterB, These a b) : mappedDiffs, Index counterA : allDiffs) +genFeaturizedTermsAndDiffs :: (Functor f, HasField fields (Maybe FeatureVector)) + => RWSEditScript f fields + -> State + (Int, Int) + ([UnmappedTerm f fields], [UnmappedTerm f fields], [MappedDiff f fields], [TermOrIndexOrNone (UnmappedTerm f fields)]) +genFeaturizedTermsAndDiffs sesDiffs = case sesDiffs of + [] -> pure ([], [], [], []) + (diff : diffs) -> do + (counterA, counterB) <- get + case diff of + This term -> do + put (succ counterA, counterB) + (as, bs, mappedDiffs, allDiffs) <- genFeaturizedTermsAndDiffs diffs + pure (featurize counterA term : as, bs, mappedDiffs, None : allDiffs ) + That term -> do + put (counterA, succ counterB) + (as, bs, mappedDiffs, allDiffs) <- genFeaturizedTermsAndDiffs diffs + pure (as, featurize counterB term : bs, mappedDiffs, Term (featurize counterB term) : allDiffs) + These a b -> do + put (succ counterA, succ counterB) + (as, bs, mappedDiffs, allDiffs) <- genFeaturizedTermsAndDiffs diffs + pure (as, bs, (These counterA counterB, These a b) : mappedDiffs, Index counterA : allDiffs) featurize :: (HasField fields (Maybe 
FeatureVector), Functor f) => Int -> Term f (Record fields) -> UnmappedTerm f fields featurize index term = UnmappedTerm index (let Just v = getField (extract term) in v) (eraseFeatureVector term)