1
1
mirror of https://github.com/github/semantic.git synced 2024-11-28 10:15:55 +03:00

add contents to Unmarshal

This commit is contained in:
Ayman Nadeem 2020-01-27 19:21:55 -05:00
parent 10a11b41ad
commit 24bcd26b75

View File

@ -1 +1,405 @@
module Unmarshal () where {-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module AST.Unmarshal
( parseByteString
, UnmarshalState(..)
, UnmarshalError(..)
, FieldName(..)
, Unmarshal(..)
, UnmarshalAnn(..)
, UnmarshalField(..)
, SymbolMatching(..)
, Match(..)
, hoist
, lookupSymbol
, unmarshalNode
, GHasAnn(..)
) where
import Control.Algebra (send)
import Control.Carrier.Reader hiding (asks)
import Control.Exception
import Control.Monad.IO.Class
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Coerce
import Data.Foldable (toList)
import qualified Data.IntMap as IntMap
import Data.List.NonEmpty (NonEmpty (..))
import Data.Proxy
import qualified Data.Text as Text
import Data.Text.Encoding
import Data.Text.Encoding.Error (lenientDecode)
import Foreign.C.String
import Foreign.Marshal.Array
import Foreign.Marshal.Utils
import Foreign.Ptr
import Foreign.Storable
import GHC.Generics
import GHC.Records
import GHC.TypeLits
import Source.Loc
import Source.Span
import TreeSitter.Cursor as TS
import TreeSitter.Language as TS
import TreeSitter.Node as TS
import TreeSitter.Parser as TS
import AST.Token as TS
import TreeSitter.Tree as TS
asks :: Has (Reader r) sig m => (r -> r') -> m r'
asks f = send (Ask (pure . f))
{-# INLINE asks #-}
-- Parse source code and produce AST
parseByteString :: (Unmarshal t, UnmarshalAnn a) => Ptr TS.Language -> ByteString -> IO (Either String (t a))
parseByteString language bytestring = withParser language $ \ parser -> withParseTree parser bytestring $ \ treePtr ->
if treePtr == nullPtr then
pure (Left "error: didn't get a root node")
else
withRootNode treePtr $ \ rootPtr ->
withCursor (castPtr rootPtr) $ \ cursor ->
(Right <$> runReader (UnmarshalState bytestring cursor) (liftIO (peek rootPtr) >>= unmarshalNode))
`catch` (pure . Left . getUnmarshalError)
newtype UnmarshalError = UnmarshalError { getUnmarshalError :: String }
deriving (Show)
instance Exception UnmarshalError
data UnmarshalState = UnmarshalState
{ source :: {-# UNPACK #-} !ByteString
, cursor :: {-# UNPACK #-} !(Ptr Cursor)
}
type MatchM = ReaderC UnmarshalState IO
newtype Match t = Match
{ runMatch :: forall a . UnmarshalAnn a => Node -> MatchM (t a)
}
-- | A church-encoded binary tree with constant-time 'singleton', 'mempty', '<>', and 'fmap', and linear-time 'foldMap'.
newtype B a = B (forall r . (r -> r -> r) -> (a -> r) -> r -> r)
instance Functor B where
fmap f (B run) = B (\ fork leaf -> run fork (leaf . f))
{-# INLINE fmap #-}
a <$ B run = B (\ fork leaf -> run fork (leaf . const a))
{-# INLINE (<$) #-}
instance Semigroup (B a) where
B l <> B r = B (\ fork leaf nil -> fork (l fork leaf nil) (r fork leaf nil))
{-# INLINE (<>) #-}
instance Monoid (B a) where
mempty = B (\ _ _ nil -> nil)
{-# INLINE mempty #-}
instance Foldable B where
foldMap f (B run) = run (<>) f mempty
{-# INLINE foldMap #-}
singleton :: a -> B a
singleton a = B (\ _ leaf _ -> leaf a)
{-# INLINE singleton #-}
hoist :: (forall x . t x -> t' x) -> Match t -> Match t'
hoist f (Match run) = Match (fmap f . run)
{-# INLINE hoist #-}
lookupSymbol :: TSSymbol -> IntMap.IntMap a -> Maybe a
lookupSymbol sym map = IntMap.lookup (fromIntegral sym) map
{-# INLINE lookupSymbol #-}
-- | Unmarshal a node
unmarshalNode :: forall t a .
( UnmarshalAnn a
, Unmarshal t
)
=> Node
-> MatchM (t a)
unmarshalNode node = case lookupSymbol (nodeSymbol node) matchers' of
Just t -> runMatch t node
Nothing -> liftIO . throwIO . UnmarshalError $ showFailure (Proxy @t) node
{-# INLINE unmarshalNode #-}
-- | Unmarshalling is the process of iterating over tree-sitters parse trees using its tree cursor API and producing Haskell ASTs for the relevant nodes.
--
-- Datatypes which can be constructed from tree-sitter parse trees may use the default definition of 'matchers' providing that they have a suitable 'Generic1' instance.
class SymbolMatching t => Unmarshal t where
matchers' :: IntMap.IntMap (Match t)
matchers' = IntMap.fromList (toList matchers)
matchers :: B (Int, Match t)
default matchers :: (Generic1 t, GUnmarshal (Rep1 t)) => B (Int, Match t)
matchers = foldMap (singleton . (, match)) (matchedSymbols (Proxy @t))
where match = Match $ \ node -> do
cursor <- asks cursor
goto cursor (nodeTSNode node)
fmap to1 (gunmarshalNode node)
instance (Unmarshal f, Unmarshal g) => Unmarshal (f :+: g) where
matchers = fmap (fmap (hoist L1)) matchers <> fmap (fmap (hoist R1)) matchers
instance Unmarshal t => Unmarshal (Rec1 t) where
matchers = coerce (matchers @t)
instance (KnownNat n, KnownSymbol sym) => Unmarshal (Token sym n) where
matchers = singleton (fromIntegral (natVal (Proxy @n)), Match (fmap Token . unmarshalAnn))
-- | Unmarshal an annotation field.
--
-- Leaf nodes have 'Text.Text' fields, and leaves, anonymous leaves, and products all have parametric annotation fields. All of these fields are unmarshalled using the metadata of the node, e.g. its start/end bytes, without reference to any child nodes it may contain.
class UnmarshalAnn a where
unmarshalAnn
:: Node
-> MatchM a
instance UnmarshalAnn () where
unmarshalAnn _ = pure ()
instance UnmarshalAnn Text.Text where
unmarshalAnn node = do
range <- unmarshalAnn node
asks (decodeUtf8With lenientDecode . slice range . source)
-- | Instance for pairs of annotations
instance (UnmarshalAnn a, UnmarshalAnn b) => UnmarshalAnn (a,b) where
unmarshalAnn node = (,)
<$> unmarshalAnn @a node
<*> unmarshalAnn @b node
instance UnmarshalAnn Loc where
unmarshalAnn node = Loc
<$> unmarshalAnn @Range node
<*> unmarshalAnn @Span node
instance UnmarshalAnn Range where
unmarshalAnn node = do
let start = fromIntegral (nodeStartByte node)
end = fromIntegral (nodeEndByte node)
pure (Range start end)
instance UnmarshalAnn Span where
unmarshalAnn node = do
let spanStart = pointToPos (nodeStartPoint node)
spanEnd = pointToPos (nodeEndPoint node)
pure (Span spanStart spanEnd)
pointToPos :: TSPoint -> Pos
pointToPos (TSPoint line column) = Pos (fromIntegral line) (fromIntegral column)
-- | Optional/repeated fields occurring in product datatypes are wrapped in type constructors, e.g. 'Maybe', '[]', or 'NonEmpty', and thus can unmarshal zero or more nodes for the same field name.
class UnmarshalField t where
unmarshalField
:: ( Unmarshal f
, UnmarshalAnn a
)
=> String -- ^ datatype name
-> String -- ^ field name
-> [Node] -- ^ nodes
-> MatchM (t (f a))
instance UnmarshalField Maybe where
unmarshalField _ _ [] = pure Nothing
unmarshalField _ _ [x] = Just <$> unmarshalNode x
unmarshalField d f _ = liftIO . throwIO . UnmarshalError $ "type '" <> d <> "' expected zero or one nodes in field '" <> f <> "' but got multiple"
instance UnmarshalField [] where
unmarshalField d f (x:xs) = do
head' <- unmarshalNode x
tail' <- unmarshalField d f xs
pure $ head' : tail'
unmarshalField _ _ [] = pure []
instance UnmarshalField NonEmpty where
unmarshalField d f (x:xs) = do
head' <- unmarshalNode x
tail' <- unmarshalField d f xs
pure $ head' :| tail'
unmarshalField d f [] = liftIO . throwIO . UnmarshalError $ "type '" <> d <> "' expected one or more nodes in field '" <> f <> "' but got zero"
class SymbolMatching (a :: * -> *) where
matchedSymbols :: Proxy a -> [Int]
-- | Provide error message describing the node symbol vs. the symbols this can match
showFailure :: Proxy a -> Node -> String
instance SymbolMatching f => SymbolMatching (M1 i c f) where
matchedSymbols _ = matchedSymbols (Proxy @f)
showFailure _ = showFailure (Proxy @f)
instance SymbolMatching f => SymbolMatching (Rec1 f) where
matchedSymbols _ = matchedSymbols (Proxy @f)
showFailure _ = showFailure (Proxy @f)
instance (KnownNat n, KnownSymbol sym) => SymbolMatching (Token sym n) where
matchedSymbols _ = [fromIntegral (natVal (Proxy @n))]
showFailure _ _ = "expected " ++ symbolVal (Proxy @sym)
instance (SymbolMatching f, SymbolMatching g) => SymbolMatching (f :+: g) where
matchedSymbols _ = matchedSymbols (Proxy @f) <> matchedSymbols (Proxy @g)
showFailure _ = sep <$> showFailure (Proxy @f) <*> showFailure (Proxy @g)
sep :: String -> String -> String
sep a b = a ++ ". " ++ b
-- | Move the cursor to point at the passed 'TSNode'.
goto :: Ptr Cursor -> TSNode -> MatchM ()
goto cursor node = liftIO (with node (ts_tree_cursor_reset_p cursor))
type Fields = [(FieldName, Node)]
-- | Return the fields remaining in the current branch, represented as 'Map.Map' of 'FieldName's to their corresponding 'Node's.
getFields :: Ptr Cursor -> Node -> MatchM Fields
getFields cursor node
| maxCount == 0 = pure []
| otherwise = do
nodes <- liftIO . allocaArray maxCount $ \ ptr -> do
actualCount <- ts_tree_cursor_copy_child_nodes cursor ptr
peekArray (fromIntegral actualCount) ptr
traverse (\ node -> (, node) <$> getFieldName node) nodes
where
maxCount = fromIntegral (nodeChildCount node)
getFieldName node
| nodeFieldName node == nullPtr = pure (FieldName "extraChildren")
| otherwise = FieldName . toHaskellCamelCaseIdentifier <$> liftIO (peekCString (nodeFieldName node))
lookupField :: FieldName -> Fields -> [Node]
lookupField k = map snd . filter ((== k) . fst)
-- | Return a 'ByteString' that contains a slice of the given 'ByteString'.
slice :: Range -> ByteString -> ByteString
slice (Range start end) = take . drop
where drop = B.drop start
take = B.take (end - start)
newtype FieldName = FieldName { getFieldName :: String }
deriving (Eq, Ord, Show)
-- | Generic construction of ASTs from a 'Map.Map' of named fields.
--
-- Product types (specifically, record types) are constructed by looking up the node for each corresponding field name in the map, moving the cursor to it, and then invoking 'unmarshalNode' to construct the value for that field. Leaf types are constructed as a special case of product types.
--
-- Sum types are constructed by using the current nodes symbol to select the corresponding constructor deterministically.
class GUnmarshal f where
gunmarshalNode
:: UnmarshalAnn a
=> Node
-> MatchM (f a)
instance (Datatype d, GUnmarshalData f) => GUnmarshal (M1 D d f) where
gunmarshalNode = go (gunmarshalNode' (datatypeName @d undefined)) where
go :: (Node -> MatchM (f a)) -> Node -> MatchM (M1 i c f a)
go = coerce
class GUnmarshalData f where
gunmarshalNode'
:: UnmarshalAnn a
=> String
-> Node
-> MatchM (f a)
instance GUnmarshalData f => GUnmarshalData (M1 i c f) where
gunmarshalNode' = go gunmarshalNode' where
go :: (String -> Node -> MatchM (f a)) -> String -> Node -> MatchM (M1 i c f a)
go = coerce
-- For anonymous leaf nodes:
instance GUnmarshalData U1 where
gunmarshalNode' _ _ = pure U1
-- For unary products:
instance UnmarshalAnn k => GUnmarshalData (K1 c k) where
gunmarshalNode' _ = go unmarshalAnn where
go :: (Node -> MatchM k) -> Node -> MatchM (K1 c k a)
go = coerce
-- For anonymous leaf nodes
instance GUnmarshalData Par1 where
gunmarshalNode' _ = go unmarshalAnn where
go :: (Node -> MatchM a) -> Node -> MatchM (Par1 a)
go = coerce
instance Unmarshal t => GUnmarshalData (Rec1 t) where
gunmarshalNode' _ = go unmarshalNode where
go :: (Node -> MatchM (t a)) -> Node -> MatchM (Rec1 t a)
go = coerce
-- For product datatypes:
instance (GUnmarshalProduct f, GUnmarshalProduct g) => GUnmarshalData (f :*: g) where
gunmarshalNode' datatypeName node = asks cursor >>= flip getFields node >>= gunmarshalProductNode @(f :*: g) datatypeName node
-- | Generically unmarshal products
class GUnmarshalProduct f where
gunmarshalProductNode
:: UnmarshalAnn a
=> String
-> Node
-> Fields
-> MatchM (f a)
-- Product structure
instance (GUnmarshalProduct f, GUnmarshalProduct g) => GUnmarshalProduct (f :*: g) where
gunmarshalProductNode datatypeName node fields = (:*:)
<$> gunmarshalProductNode @f datatypeName node fields
<*> gunmarshalProductNode @g datatypeName node fields
-- Contents of product types (ie., the leaves of the product tree)
instance UnmarshalAnn k => GUnmarshalProduct (M1 S c (K1 i k)) where
gunmarshalProductNode _ node _ = go unmarshalAnn node where
go :: (Node -> MatchM k) -> Node -> MatchM (M1 S c (K1 i k) a)
go = coerce
instance GUnmarshalProduct (M1 S c Par1) where
gunmarshalProductNode _ node _ = go unmarshalAnn node where
go :: (Node -> MatchM a) -> Node -> MatchM (M1 S c Par1 a)
go = coerce
instance (UnmarshalField f, Unmarshal g, Selector c) => GUnmarshalProduct (M1 S c (f :.: g)) where
gunmarshalProductNode datatypeName _ = go (unmarshalField datatypeName fieldName . lookupField (FieldName fieldName)) where
go :: (Fields -> MatchM (f (g a))) -> Fields -> MatchM (M1 S c (f :.: g) a)
go = coerce
fieldName = selName @c undefined
instance (Unmarshal t, Selector c) => GUnmarshalProduct (M1 S c (Rec1 t)) where
gunmarshalProductNode datatypeName _ fields =
case lookupField (FieldName fieldName) fields of
[] -> liftIO . throwIO . UnmarshalError $ "type '" <> datatypeName <> "' expected a node '" <> fieldName <> "' but didn't get one"
[x] -> go unmarshalNode x where
go :: (Node -> MatchM (t a)) -> Node -> MatchM (M1 S c (Rec1 t) a)
go = coerce
_ -> liftIO . throwIO . UnmarshalError $ "type '" <> datatypeName <> "' expected a node but got multiple"
where
fieldName = selName @c undefined
class GHasAnn a t where
gann :: t a -> a
instance GHasAnn a f => GHasAnn a (M1 i c f) where
gann = gann . unM1
instance (GHasAnn a l, GHasAnn a r) => GHasAnn a (l :+: r) where
gann (L1 l) = gann l
gann (R1 r) = gann r
instance {-# OVERLAPPABLE #-} HasField "ann" (t a) a => GHasAnn a t where
gann = getField @"ann"