diff --git a/src/Data/RandomWalkSimilarity.hs b/src/Data/RandomWalkSimilarity.hs index c2c01b848..de578aa2c 100644 --- a/src/Data/RandomWalkSimilarity.hs +++ b/src/Data/RandomWalkSimilarity.hs @@ -21,7 +21,6 @@ import Control.Monad.State import Data.Align.Generic import Data.Array import Data.Functor.Both hiding (fst, snd) -import Data.Functor.Foldable import Data.Functor.Listable import Data.Hashable import qualified Data.IntMap as IntMap @@ -244,36 +243,18 @@ defaultFeatureVectorDecorator -> Term f (Record (FeatureVector ': fields)) defaultFeatureVectorDecorator getLabel = featureVectorDecorator getLabel defaultP defaultQ defaultD -type VState = State (IntMap.IntMap FeatureVector) - -- | Annotates a term with a feature vector at each node, parameterized by stem length, base width, and feature vector dimensions. -featureVectorDecorator :: forall f label fields . (Hashable label, Traversable f) => Label f fields label -> Int -> Int -> Int -> Term f (Record fields) -> Term f (Record (FeatureVector ': fields)) +featureVectorDecorator :: (Hashable label, Traversable f) => Label f fields label -> Int -> Int -> Int -> Term f (Record fields) -> Term f (Record (FeatureVector ': fields)) featureVectorDecorator getLabel p q d - = (`evalState` IntMap.empty) . cata collect . pqGramDecorator getLabel p q - where collect :: CofreeF f (Record (Gram label ': fields)) (VState (Term f (Record (FeatureVector ': fields)))) -> VState (Term f (Record (FeatureVector ': fields))) - collect ((gram :. rest) :< functorState) = do - featureVector <- foldl' addSubtermVector (unitVector' d (hash gram)) functorState - functor <- sequenceA functorState - pure $! cofree ((featureVector :. rest) :< functor) - addSubtermVector :: VState FeatureVector -> VState (Term f (Record (FeatureVector ': fields))) -> VState FeatureVector - addSubtermVector accumState termState = do - accum <- accumState - term <- termState - pure $! addVectors accum (rhead (extract term)) + = cata collect + . pqGramDecorator getLabel p q + where collect ((gram :. rest) :< functor) = cofree ((foldl' addSubtermVector (unitVector d (hash gram)) functor :. rest) :< functor) + addSubtermVector :: FeatureVector -> Term f (Record (FeatureVector ': fields)) -> FeatureVector + addSubtermVector = flip $ addVectors . rhead . headF . runCofree addVectors :: Num a => Array Int a -> Array Int a -> Array Int a addVectors as bs = listArray (0, d - 1) (fmap (\ i -> as ! i + bs ! i) [0..(d - 1)]) -unitVector' :: Int -> Int -> State (IntMap.IntMap FeatureVector) FeatureVector -unitVector' d hash = do - map <- get - case IntMap.lookup hash map of - Just v -> pure v - _ -> do - let v = unitVector d hash - put (IntMap.insert hash v map) - pure v - -- | Annotates a term with the corresponding p,q-gram at each node. pqGramDecorator :: Traversable f