NOTE(review): this patch was recovered from a copy whose newlines had been collapsed.
Line breaks and leading indentation inside the hunks are reconstructed to satisfy the
hunk line counts and Haskell layout; confirm against the target file before applying
(or apply with whitespace-insensitive matching).

diff --git a/src/Data/RandomWalkSimilarity.hs b/src/Data/RandomWalkSimilarity.hs
index 9487da71f..bcaa59bfd 100644
--- a/src/Data/RandomWalkSimilarity.hs
+++ b/src/Data/RandomWalkSimilarity.hs
@@ -2,6 +2,7 @@
 module Data.RandomWalkSimilarity
 ( rws
 , pqGramDecorator
+, defaultFeatureVectorDecorator
 , featureVectorDecorator
 , editDistanceUpTo
 , defaultD
@@ -125,13 +126,17 @@ unitVector d hash = normalize ((`evalRand` mkQCGen hash) (sequenceA (Vector.repl
   where normalize vec = fmap (/ vmagnitude vec) vec
         vmagnitude = sqrtDouble . Vector.sum . fmap (** 2)
 
--- | Annotates a term with a feature vector at each node.
+-- | Annotates a term with a feature vector at each node, parameterized by stem length, base width, and feature vector dimensions.
 featureVectorDecorator :: (Hashable label, Traversable f) => (forall b. CofreeF f (Record fields) b -> label) -> Int -> Int -> Int -> Cofree f (Record fields) -> Cofree f (Record (Vector.Vector Double ': fields))
 featureVectorDecorator getLabel p q d
   = cata (\ (RCons gram rest :< functor) ->
       cofree ((foldr (Vector.zipWith (+) . getField . extract) (unitVector d (hash gram)) functor .: rest) :< functor))
   . pqGramDecorator getLabel p q
 
+-- | Annotates a term with a feature vector at each node, using the default values for the p, q, and d parameters.
+defaultFeatureVectorDecorator :: (Hashable label, Traversable f) => (forall b. CofreeF f (Record fields) b -> label) -> Cofree f (Record fields) -> Cofree f (Record (Vector.Vector Double ': fields))
+defaultFeatureVectorDecorator getLabel = featureVectorDecorator getLabel defaultP defaultQ defaultD
+
 -- | Strips the head annotation off a term annotated with non-empty records.
 stripTerm :: Functor f => Cofree f (Record (h ': t)) -> Cofree f (Record t)
 stripTerm = fmap rtail