diff --git a/src/RWS.hs b/src/RWS.hs index 766fbdb7e..da8a97453 100644 --- a/src/RWS.hs +++ b/src/RWS.hs @@ -96,7 +96,7 @@ run :: (Eq1 f, Functor f, HasField fields Category, HasField fields (Maybe Featu run editDistance canCompare as bs = relay pure (\m k -> case m of SES -> k $ ses (gliftEq (==) `on` fmap category) as bs (GenFeaturizedTermsAndDiffs sesDiffs) -> - k $ evalState (genFeaturizedTermsAndDiffs sesDiffs) (0, 0) + k $ genFeaturizedTermsAndDiffs sesDiffs (FindNearestNeighoursToDiff allDiffs featureAs featureBs) -> k $ findNearestNeighboursToDiff editDistance canCompare allDiffs featureAs featureBs (DeleteRemaining allDiffs remainingDiffs) -> @@ -233,24 +233,18 @@ insertion previous unmappedA unmappedB (UnmappedTerm j _ b) = do put (previous, unmappedA, IntMap.delete j unmappedB) pure (That j, That b) -genFeaturizedTermsAndDiffs :: (Functor f, HasField fields (Maybe FeatureVector)) => RWSEditScript f fields -> State (Int, Int) ([UnmappedTerm f fields], [UnmappedTerm f fields], [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))], [TermOrIndexOrNone (UnmappedTerm f fields)]) -genFeaturizedTermsAndDiffs sesDiffs = case sesDiffs of - [] -> pure ([], [], [], []) - (diff : diffs) -> do - (counterA, counterB) <- get - case diff of - This term -> do - put (succ counterA, counterB) - (as, bs, mappedDiffs, allDiffs) <- genFeaturizedTermsAndDiffs diffs - pure (as <> pure (featurize counterA term), bs, mappedDiffs, allDiffs <> pure None) - That term -> do - put (counterA, succ counterB) - (as, bs, mappedDiffs, allDiffs) <- genFeaturizedTermsAndDiffs diffs - pure (as, bs <> pure (featurize counterB term), mappedDiffs, allDiffs <> pure (Term (featurize counterB term))) - These a b -> do - put (succ counterA, succ counterB) - (as, bs, mappedDiffs, allDiffs) <- genFeaturizedTermsAndDiffs diffs - pure (as, bs, mappedDiffs <> pure (These counterA counterB, These a b), allDiffs <> pure (Index counterA)) +genFeaturizedTermsAndDiffs :: (Functor f, HasField fields (Maybe FeatureVector)) => RWSEditScript f fields -> ([UnmappedTerm f fields], [UnmappedTerm f fields], [(These Int Int, These (Term f (Record fields)) (Term f (Record fields)))], [TermOrIndexOrNone (UnmappedTerm f fields)]) +genFeaturizedTermsAndDiffs sesDiffs = pure (featurizedAs, featurizedBs, countersAndDiffs, allDiffs) + where + (featurizedAs, featurizedBs, _, _, countersAndDiffs, allDiffs) = foldl' (\(as, bs, counterA, counterB, diffs, allDiffs) diff -> + case diff of + This term -> + (as <> pure (featurize counterA term), bs, succ counterA, counterB, diffs, allDiffs <> pure None) + That term -> + (as, bs <> pure (featurize counterB term), counterA, succ counterB, diffs, allDiffs <> pure (Term (featurize counterB term))) + These a b -> + (as, bs, succ counterA, succ counterB, diffs <> pure (These counterA counterB, These a b), allDiffs <> pure (Index counterA)) + ) ([], [], 0, 0, [], []) sesDiffs featurize :: (HasField fields (Maybe FeatureVector), Functor f) => Int -> Term f (Record fields) -> UnmappedTerm f fields featurize index term = UnmappedTerm index (let Just v = getField (extract term) in v) (eraseFeatureVector term)