diff --git a/src/Data/RandomWalkSimilarity.hs b/src/Data/RandomWalkSimilarity.hs index 2d1ecba6f..0a7971b45 100644 --- a/src/Data/RandomWalkSimilarity.hs +++ b/src/Data/RandomWalkSimilarity.hs @@ -56,8 +56,8 @@ windowed n f seed = para alg type Bag = DList.DList -featureVector :: Hashable label => Bag (Gram label) -> Int -> Vector.Vector Double -featureVector bag d = sumVectors $ unitDVector . hash <$> bag +featureVector :: Hashable label => Int -> Bag (Gram label) -> Vector.Vector Double +featureVector d bag = sumVectors $ unitDVector . hash <$> bag where unitDVector hash = normalize . (`evalRand` mkQCGen hash) $ Prologue.sequence (Vector.replicate d getRandom) normalize vec = fmap (/ vmagnitude vec) vec sumVectors = DList.foldr (Vector.zipWith (+)) (Vector.replicate d 0) diff --git a/test/Data/RandomWalkSimilarity/Spec.hs b/test/Data/RandomWalkSimilarity/Spec.hs index 9a6732384..fdb886fe1 100644 --- a/test/Data/RandomWalkSimilarity/Spec.hs +++ b/test/Data/RandomWalkSimilarity/Spec.hs @@ -23,4 +23,4 @@ spec = parallel $ do describe "featureVector" $ do prop "produces a vector of the specified dimension" . forAll (arbitrary `suchThat` ((> 0) . snd)) $ - \ (grams, d) -> length (featureVector (fromList (grams :: [Gram String])) d) `shouldBe` d + \ (grams, d) -> length (featureVector d (fromList (grams :: [Gram String]))) `shouldBe` d