diff --git a/src/Data/RandomWalkSimilarity.hs b/src/Data/RandomWalkSimilarity.hs index c944b3304..58680da49 100644 --- a/src/Data/RandomWalkSimilarity.hs +++ b/src/Data/RandomWalkSimilarity.hs @@ -245,7 +245,7 @@ defaultFeatureVectorDecorator getLabel = featureVectorDecorator getLabel default featureVectorDecorator :: (Hashable label, Traversable f) => Label f fields label -> Int -> Int -> Int -> Term f (Record fields) -> Term f (Record (Vector.Vector Double ': fields)) featureVectorDecorator getLabel p q d = cata (\ ((gram :. rest) :< functor) -> - cofree ((foldr (Vector.zipWith (+) . getField . extract) (unitVector d (hash gram)) functor :. rest) :< functor)) + cofree ((foldr (Vector.zipWith (+) . rhead . extract) (unitVector d (hash gram)) functor :. rest) :< functor)) . pqGramDecorator getLabel p q -- | Annotates a term with the corresponding p,q-gram at each node. @@ -277,7 +277,7 @@ pqGramDecorator getLabel p q = cata algebra -- | Computes a unit vector of the specified dimension from a hash. unitVector :: Int -> Int -> Vector.Vector Double -unitVector d hash = normalize ((`evalRand` mkQCGen hash) (sequenceA (Vector.replicate d getRandom))) +unitVector d hash = normalize ((`evalRand` mkQCGen hash) (Vector.fromList . take d <$> getRandoms)) where normalize vec = fmap (/ vmagnitude vec) vec vmagnitude = sqrtDouble . Vector.sum . fmap (** 2)