diff --git a/semantic-diff.cabal b/semantic-diff.cabal index 63d92519d..fd94cbf0b 100644 --- a/semantic-diff.cabal +++ b/semantic-diff.cabal @@ -66,6 +66,7 @@ library , QuickCheck >= 2.8.1 , quickcheck-text , semigroups + , syb , text >= 1.2.1.3 , text-icu , these diff --git a/src/Term.hs b/src/Term.hs index 8aa956f91..0c8eb41a8 100644 --- a/src/Term.hs +++ b/src/Term.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE RankNTypes, TypeFamilies, TypeSynonymInstances #-} +{-# LANGUAGE ScopedTypeVariables, RankNTypes, TypeFamilies, TypeSynonymInstances #-} {-# OPTIONS_GHC -fno-warn-orphans #-} module Term where @@ -9,6 +9,9 @@ import Data.Functor.Both import Data.OrderedMap hiding (size) import Data.These import Syntax +import Data.Data +import Data.Generics.Twins +import Unsafe.Coerce -- | An annotated node (Syntax) in an abstract syntax tree. type TermF a annotation = CofreeF (Syntax a) annotation @@ -50,3 +53,21 @@ alignSyntax' a b = case (a, b) of (Fixed a, Fixed b) -> Just (Fixed (align a b)) (Keyed a, Keyed b) -> Just (Keyed (align a b)) _ -> Nothing + +alignF :: (Data (f a), Data (f b), Data (f (These a b)), Typeable a, Typeable b) => f a -> f b -> Maybe (f b) +alignF a b = do + guard (toConstr a == toConstr b) + alignM a b + where alignM :: (Data a, Data b, Alternative m, Monad m) => a -> b -> m b + alignM a b = gzipWithM go a b + where go :: forall m a b. (Data a, Data b, Alternative m, Monad m) => a -> b -> m b + go a b = do + guard (toConstr a == toConstr b) + fromConstrM (do + b' <- guardCast b + alignM a b') (toConstr b) + + guardCast :: forall f a b. (Typeable a, Typeable b, Alternative f) => a -> f b + guardCast a = + guard (typeRep (Proxy :: Proxy a) == typeRep (Proxy :: Proxy b)) + *> unsafeCoerce a