diff --git a/src/Data/RandomWalkSimilarity.hs b/src/Data/RandomWalkSimilarity.hs index 1fd90953e..290ef17f4 100644 --- a/src/Data/RandomWalkSimilarity.hs +++ b/src/Data/RandomWalkSimilarity.hs @@ -22,16 +22,17 @@ rws compare getLabel as bs | null as, null bs = [] | null as = insert <$> bs | null bs = delete <$> as - | otherwise = (`evalState` Set.empty) $ traverse findNearestNeighbourTo (featurize <$> bs) + | otherwise = uncurry deleteRemaining . (`runState` Set.empty) $ traverse findNearestNeighbourTo (featurize <$> bs) where insert = pure . Insert delete = pure . Delete (p, q) = (2, 2) d = 15 - fas = KdTree.build (Vector.toList . fst) (featurize <$> as) + fas = featurize <$> as + kdas = KdTree.build (Vector.toList . fst) fas featurize = featureVector d . pqGrams p q getLabel &&& identity findNearestNeighbourTo kv@(_, v) = do mapped <- get - let (k, nearest) = (KdTree.nearest fas kv) + let (k, nearest) = (KdTree.nearest kdas kv) if k `Set.member` mapped then pure $! insert v else case compare nearest v of @@ -39,6 +40,7 @@ rws compare getLabel as bs put (Set.insert k mapped) pure y _ -> pure $! delete v + deleteRemaining diff mapped = diff <> (delete . snd <$> filter ((`Set.member` mapped) . fst) fas) data Gram label = Gram { stem :: [Maybe label], base :: [Maybe label] } deriving (Eq, Show)