mirror of
https://github.com/github/semantic.git
synced 2025-01-08 08:30:27 +03:00
Replace cofree with Term and TermF
This commit is contained in:
parent
86a38579cd
commit
f026a2d2f1
@ -25,7 +25,7 @@ import Data.Record
|
||||
import qualified Data.Vector as Vector
|
||||
import Patch
|
||||
import Prologue as P
|
||||
import Term (termSize, zipTerms, Term)
|
||||
import Term (termSize, zipTerms, Term, TermF)
|
||||
import Test.QuickCheck hiding (Fixed)
|
||||
import Test.QuickCheck.Random
|
||||
import qualified SES
|
||||
@ -43,15 +43,12 @@ rws :: forall f fields.
|
||||
(GAlign f,
|
||||
Traversable f,
|
||||
HasField fields Category,
|
||||
Eq (f (Cofree f Category)),
|
||||
HasField fields (Vector.Vector Double)) =>
|
||||
(Term f (Record fields)
|
||||
-> Term f (Record fields)
|
||||
-> Maybe (Diff f (Record fields))) -- ^ A function which compares a pair of terms recursively, returning 'Just' their diffed value if appropriate, or 'Nothing' if they should not be compared.
|
||||
-- | The list of old terms.
|
||||
-> [Cofree f (Record fields)] -- ^ The list of new terms.
|
||||
-> [Cofree f (Record fields)] -- ^ The resulting list of similarity-matched diffs.
|
||||
-> [Free (CofreeF f (Both (Record fields))) (Patch (Cofree f (Record fields)))]
|
||||
Eq (f (Term f Category)),
|
||||
HasField fields (Vector.Vector Double))
|
||||
=> (Term f (Record fields) -> Term f (Record fields) -> Maybe (Diff f (Record fields))) -- ^ A function which compares a pair of terms recursively, returning 'Just' their diffed value if appropriate, or 'Nothing' if they should not be compared.
|
||||
-> [Term f (Record fields)] -- ^ The list of old terms.
|
||||
-> [Term f (Record fields)] -- ^ The list of new terms.
|
||||
-> [Diff f (Record fields)] -- ^ The resulting list of similarity-matched diffs.
|
||||
rws compare as bs
|
||||
| null as, null bs = []
|
||||
| null as = inserting <$> bs
|
||||
@ -62,28 +59,17 @@ rws compare as bs
|
||||
traverse findNearestNeighbourToDiff allDiffs &
|
||||
fmap catMaybes &
|
||||
-- Run the state with an initial state
|
||||
(`runState` (pred $ maybe 0 getMin (getOption (foldMap (Option . Just . Min . termIndex) fas)),
|
||||
toMap fas,
|
||||
toMap fbs)) &
|
||||
(`runState` (pred $ maybe 0 getMin (getOption (foldMap (Option . Just . Min . termIndex) featurizedAs)),
|
||||
toMap featurizedAs,
|
||||
toMap featurizedBs)) &
|
||||
uncurry deleteRemaining &
|
||||
insertMapped countersAndDiffs &
|
||||
fmap snd
|
||||
|
||||
where
|
||||
sesDiffs = eitherCutoff 1 <$> SES.ses replaceIfEqual cost as bs
|
||||
replaceIfEqual :: HasField fields Category => Cofree f (Record fields) -> Cofree f (Record fields) -> Maybe (Free (CofreeF f (Both (Record fields))) (Patch (Cofree f (Record fields))))
|
||||
replaceIfEqual a b
|
||||
| (category <$> a) == (category <$> b) = hylo wrap runCofree <$> zipTerms a b
|
||||
| otherwise = Nothing
|
||||
cost = iter (const 0) . (1 <$)
|
||||
|
||||
eitherCutoff :: (Functor f) => Integer
|
||||
-> Free (CofreeF f (Both (Record fields))) (Patch (Cofree f (Record fields)))
|
||||
-> Free (CofreeF f (Both (Record fields))) (Either (Free (CofreeF f (Both (Record fields))) (Patch (Cofree f (Record fields)))) (Patch (Cofree f (Record fields))))
|
||||
eitherCutoff n diff | n <= 0 = pure (Left diff)
|
||||
eitherCutoff n diff = free . bimap Right (eitherCutoff (pred n)) $ runFree diff
|
||||
|
||||
(fas, fbs, _, _, countersAndDiffs, allDiffs) =
|
||||
(featurizedAs, featurizedBs, _, _, countersAndDiffs, allDiffs) =
|
||||
foldl' (\(as, bs, counterA, counterB, diffs, allDiffs) diff -> case runFree diff of
|
||||
Pure (Right (Delete term)) ->
|
||||
(as <> pure (featurize counterA term), bs, succ counterA, counterB, diffs, allDiffs <> pure Nil)
|
||||
@ -104,13 +90,6 @@ rws compare as bs
|
||||
put (i, unA, unB)
|
||||
pure Nothing
|
||||
|
||||
kdas = KdTree.build (Vector.toList . feature) fas
|
||||
kdbs = KdTree.build (Vector.toList . feature) fbs
|
||||
|
||||
featurize index term = UnmappedTerm index (getField (extract term)) term
|
||||
|
||||
toMap = IntMap.fromList . fmap (termIndex &&& identity)
|
||||
|
||||
-- | Construct a diff for a term in B by matching it against the most similar eligible term in A (if any),
|
||||
-- marking both as ineligible for future matches.
|
||||
findNearestNeighbourTo :: UnmappedTerm f fields
|
||||
@ -156,8 +135,6 @@ rws compare as bs
|
||||
-> Maybe (UnmappedTerm f fields) -- ^ The most similar unmapped term, if any.
|
||||
nearestUnmapped unmapped tree key = getFirst $ foldMap (First . Just) (sortOn (maybe maxBound (editDistanceUpTo defaultM) . compare (term key) . term) (toList (IntMap.intersection unmapped (toMap (KdTree.kNearest tree defaultL key)))))
|
||||
|
||||
-- | Determines whether an index is in-bounds for a move given the most recently matched index.
|
||||
isInMoveBounds previous i = previous < i && i < previous + defaultMoveBound
|
||||
insertMapped diffs into = foldl' (\into (i, mappedTerm) ->
|
||||
insertDiff (i, mappedTerm) into)
|
||||
into
|
||||
@ -169,7 +146,36 @@ rws compare as bs
|
||||
diffs
|
||||
((termIndex &&& deleting . term) <$> unmappedA)
|
||||
|
||||
insertDiff :: (These Int Int, a) -> [(These Int Int, a)] -> [(These Int Int, a)]
|
||||
replaceIfEqual :: HasField fields Category
|
||||
=> Term f (Record fields)
|
||||
-> Term f (Record fields)
|
||||
-> Maybe (Diff f (Record fields))
|
||||
replaceIfEqual a b
|
||||
| (category <$> a) == (category <$> b) = hylo wrap runCofree <$> zipTerms a b
|
||||
| otherwise = Nothing
|
||||
|
||||
cost = iter (const 0) . (1 <$)
|
||||
|
||||
eitherCutoff :: (Functor f) => Integer
|
||||
-> Diff f (Record fields)
|
||||
-> Free (TermF f (Both (Record fields)))
|
||||
(Either (Diff f (Record fields)) (Patch (Term f (Record fields))))
|
||||
eitherCutoff n diff | n <= 0 = pure (Left diff)
|
||||
eitherCutoff n diff = free . bimap Right (eitherCutoff (pred n)) $ runFree diff
|
||||
|
||||
kdas = KdTree.build (Vector.toList . feature) featurizedAs
|
||||
kdbs = KdTree.build (Vector.toList . feature) featurizedBs
|
||||
|
||||
featurize index term = UnmappedTerm index (getField (extract term)) term
|
||||
|
||||
toMap = IntMap.fromList . fmap (termIndex &&& identity)
|
||||
|
||||
|
||||
-- | Determines whether an index is in-bounds for a move given the most recently matched index.
|
||||
isInMoveBounds previous i = previous < i && i < previous + defaultMoveBound
|
||||
|
||||
-- | Inserts an index and diff pair into a list of indices and diffs.
|
||||
insertDiff :: (These Int Int, diff) -> [(These Int Int, diff)] -> [(These Int Int, diff)]
|
||||
insertDiff inserted [] = [ inserted ]
|
||||
insertDiff a@(ij1, _) (b@(ij2, _):rest) = case (ij1, ij2) of
|
||||
(These i1 i2, These j1 j2) -> if i1 <= j1 && i2 <= j2 then a : b : rest else b : insertDiff a rest
|
||||
@ -195,7 +201,7 @@ insertDiff a@(ij1, _) (b@(ij2, _):rest) = case (ij1, ij2) of
|
||||
-- | Return an edit distance as the sum of it's term sizes, given an cutoff and a syntax of terms 'f a'.
|
||||
-- | Computes a constant-time approximation to the edit distance of a diff. This is done by comparing at most _m_ nodes, & assuming the rest are zero-cost.
|
||||
editDistanceUpTo :: (P.Foldable f, Functor f) => Integer
|
||||
-> Free (CofreeF f (Both a)) (Patch (Cofree f a))
|
||||
-> Diff f annotation
|
||||
-> Int
|
||||
editDistanceUpTo m = diffSum (patchSum termSize) . cutoff m
|
||||
where diffSum patchCost = sum . fmap (maybe 0 patchCost)
|
||||
@ -213,7 +219,7 @@ defaultM :: Integer
|
||||
defaultM = 10
|
||||
|
||||
-- | A term which has not yet been mapped by `rws`, along with its feature vector summary & index.
|
||||
data UnmappedTerm f fields = UnmappedTerm { termIndex :: {-# UNPACK #-} !Int, feature :: !(Vector.Vector Double), term :: !(Cofree f (Record fields)) }
|
||||
data UnmappedTerm f fields = UnmappedTerm { termIndex :: {-# UNPACK #-} !Int, feature :: !(Vector.Vector Double), term :: !(Term f (Record fields)) }
|
||||
|
||||
-- | Either a `term`, an index of a matched term, or nil.
|
||||
data TermOrIndexOrNil term = Term term | Index Int | Nil
|
||||
@ -226,11 +232,14 @@ data Gram label = Gram { stem :: [Maybe label], base :: [Maybe label] }
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- | Annotates a term with a feature vector at each node, using the default values for the p, q, and d parameters.
|
||||
defaultFeatureVectorDecorator :: (Hashable label, Traversable f) => (forall b. CofreeF f (Record fields) b -> label) -> Cofree f (Record fields) -> Cofree f (Record (Vector.Vector Double ': fields))
|
||||
defaultFeatureVectorDecorator :: (Hashable label, Traversable f) =>
|
||||
(forall b. TermF f (Record fields) b -> label)
|
||||
-> Term f (Record fields)
|
||||
-> Term f (Record (Vector.Vector Double ': fields))
|
||||
defaultFeatureVectorDecorator getLabel = featureVectorDecorator getLabel defaultP defaultQ defaultD
|
||||
|
||||
-- | Annotates a term with a feature vector at each node, parameterized by stem length, base width, and feature vector dimensions.
|
||||
featureVectorDecorator :: (Hashable label, Traversable f) => (forall b. CofreeF f (Record fields) b -> label) -> Int -> Int -> Int -> Cofree f (Record fields) -> Cofree f (Record (Vector.Vector Double ': fields))
|
||||
featureVectorDecorator :: (Hashable label, Traversable f) => (forall b. TermF f (Record fields) b -> label) -> Int -> Int -> Int -> Term f (Record fields) -> Term f (Record (Vector.Vector Double ': fields))
|
||||
featureVectorDecorator getLabel p q d
|
||||
= cata (\ (RCons gram rest :< functor) ->
|
||||
cofree ((foldr (Vector.zipWith (+) . getField . extract) (unitVector d (hash gram)) functor .: rest) :< functor))
|
||||
@ -238,11 +247,11 @@ featureVectorDecorator getLabel p q d
|
||||
|
||||
-- | Annotates a term with the corresponding p,q-gram at each node.
|
||||
pqGramDecorator :: Traversable f
|
||||
=> (forall b. CofreeF f (Record fields) b -> label) -- ^ A function computing the label from an arbitrary unpacked term. This function can use the annotation and functor’s constructor, but not any recursive values inside the functor (since they’re held parametric in 'b').
|
||||
-> Int -- ^ 'p'; the desired stem length for the grams.
|
||||
-> Int -- ^ 'q'; the desired base length for the grams.
|
||||
-> Cofree f (Record fields) -- ^ The term to decorate.
|
||||
-> Cofree f (Record (Gram label ': fields)) -- ^ The decorated term.
|
||||
=> (forall b. TermF f (Record fields) b -> label) -- ^ A function computing the label from an arbitrary unpacked term. This function can use the annotation and functor’s constructor, but not any recursive values inside the functor (since they’re held parametric in 'b').
|
||||
-> Int -- ^ 'p'; the desired stem length for the grams.
|
||||
-> Int -- ^ 'q'; the desired base length for the grams.
|
||||
-> Term f (Record fields) -- ^ The term to decorate.
|
||||
-> Term f (Record (Gram label ': fields)) -- ^ The decorated term.
|
||||
pqGramDecorator getLabel p q = cata algebra
|
||||
where
|
||||
algebra term = let label = getLabel term in
|
||||
@ -250,13 +259,15 @@ pqGramDecorator getLabel p q = cata algebra
|
||||
gram label = Gram (padToSize p []) (padToSize q (pure (Just label)))
|
||||
assignParentAndSiblingLabels functor label = (`evalState` (replicate (q `div` 2) Nothing <> siblingLabels functor)) (for functor (assignLabels label))
|
||||
|
||||
assignLabels :: label -> Cofree f (Record (Gram label ': fields)) -> State [Maybe label] (Cofree f (Record (Gram label ': fields)))
|
||||
assignLabels :: label
|
||||
-> Term f (Record (Gram label ': fields))
|
||||
-> State [Maybe label] (Term f (Record (Gram label ': fields)))
|
||||
assignLabels label a = case runCofree a of
|
||||
RCons gram rest :< functor -> do
|
||||
labels <- get
|
||||
put (drop 1 labels)
|
||||
pure $! cofree ((gram { stem = padToSize p (Just label : stem gram), base = padToSize q labels } .: rest) :< functor)
|
||||
siblingLabels :: Traversable f => f (Cofree f (Record (Gram label ': fields))) -> [Maybe label]
|
||||
siblingLabels :: Traversable f => f (Term f (Record (Gram label ': fields))) -> [Maybe label]
|
||||
siblingLabels = foldMap (base . rhead . extract)
|
||||
padToSize n list = take n (list <> repeat empty)
|
||||
|
||||
@ -268,11 +279,13 @@ unitVector d hash = normalize ((`evalRand` mkQCGen hash) (sequenceA (Vector.repl
|
||||
vmagnitude = sqrtDouble . Vector.sum . fmap (** 2)
|
||||
|
||||
-- | Strips the head annotation off a term annotated with non-empty records.
|
||||
stripTerm :: Functor f => Cofree f (Record (h ': t)) -> Cofree f (Record t)
|
||||
stripTerm :: Functor f => Term f (Record (h ': t)) -> Term f (Record t)
|
||||
stripTerm = fmap rtail
|
||||
|
||||
-- | Strips the head annotation off a diff annotated with non-empty records.
|
||||
stripDiff :: (Functor f, Functor g) => Free (CofreeF f (g (Record (h ': t)))) (Patch (Cofree f (Record (h ': t)))) -> Free (CofreeF f (g (Record t))) (Patch (Cofree f (Record t)))
|
||||
stripDiff :: (Functor f, Functor g)
|
||||
=> Free (TermF f (g (Record (h ': t)))) (Patch (Term f (Record (h ': t))))
|
||||
-> Free (TermF f (g (Record t))) (Patch (Term f (Record t)))
|
||||
stripDiff = iter (\ (h :< f) -> wrap (fmap rtail h :< f)) . fmap (pure . fmap stripTerm)
|
||||
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
module Term.Arbitrary where
|
||||
|
||||
import Data.Functor.Foldable (Base, cata, unfold, Unfoldable(embed))
|
||||
import Data.Functor.Foldable (Base, unfold, Unfoldable(embed))
|
||||
import Data.Text.Arbitrary ()
|
||||
import Prologue
|
||||
import Syntax
|
||||
|
Loading…
Reference in New Issue
Block a user