graphql-engine/server/src-lib/Hasura/GraphQL/Transport/WebSocket/Protocol.hs
hasura-bot a886da2f21 server: address recent graphql-ws related bugs
GITHUB_PR_NUMBER: 7730
GITHUB_PR_URL: https://github.com/hasura/graphql-engine/pull/7730

PR-URL: https://github.com/hasura/graphql-engine-mono/pull/2685
Co-authored-by: Sameer Kolhar <6604943+kolharsam@users.noreply.github.com>
GitOrigin-RevId: 55bafd4eb1576e95803350f3ba9c7920a21de037
2021-11-04 12:40:02 +00:00

298 lines
9.2 KiB
Haskell

-- | Types for both websocket sub-protocols supported by the server: Apollo's
-- subscriptions-transport-ws and graphql-ws.
--
-- See Apollo: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
-- See graphql-ws: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md
module Hasura.GraphQL.Transport.WebSocket.Protocol where
import Control.Concurrent
import Control.Concurrent.Extended (sleep)
import Control.Concurrent.STM
import Data.Aeson qualified as J
import Data.Aeson.TH qualified as J
import Data.ByteString.Lazy qualified as BL
import Data.Text (pack)
import Hasura.EncJSON
import Hasura.GraphQL.Transport.HTTP.Protocol
import Hasura.Prelude
-- | The websocket sub-protocol a client speaks.
--
-- NOTE: the sub-protocol is decided based on the @Sec-WebSocket-Protocol@
-- header on every request sent to the server.
data WSSubProtocol
  = -- | Apollo's @subscriptions-transport-ws@ protocol.
    Apollo
  | -- | The @graphql-ws@ protocol.
    GraphQLWS
  deriving (Eq, Show)
-- | The name that identifies each sub-protocol.
--
-- NOTE: Please do not change these strings: they are used to identify the
-- type of client on every request that reaches the server, and each one is
-- unique to its protocol.
showSubProtocol :: WSSubProtocol -> String
-- REF: https://github.com/apollographql/subscriptions-transport-ws/blob/master/src/server.ts#L144
showSubProtocol Apollo = "graphql-ws"
-- REF: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#communication
showSubProtocol GraphQLWS = "graphql-transport-ws"
-- | Map a sub-protocol name back to a 'WSSubProtocol'.
--
-- Only @graphql-transport-ws@ selects 'GraphQLWS'. Every other string falls
-- back to 'Apollo', including @graphql-ws@ and unrecognised values.
toWSSubProtocol :: String -> WSSubProtocol
toWSSubProtocol name
  | name == "graphql-transport-ws" = GraphQLWS
  | otherwise = Apollo
-- | Identifier of a single operation on a connection.
--
-- It is chosen by the client and arrives in client messages such as
-- 'StartMsg' and 'StopMsg'.
newtype OperationId = OperationId
  { unOperationId :: Text
  }
  deriving (Show, Eq, J.ToJSON, J.FromJSON, IsString, Hashable)
-- | The @type@ tag of every message the server sends.
--
-- The wire string for each tag is produced by this type's 'Show' instance,
-- which its 'J.ToJSON' instance reuses.
data ServerMsgType
  = -- | @ka@ (Apollo only)
    SMT_GQL_CONNECTION_KEEP_ALIVE
  | -- | @connection_error@ (Apollo only)
    SMT_GQL_CONNECTION_ERROR
  | -- | @data@ (Apollo only)
    SMT_GQL_DATA
  | -- | @next@ (graphql-ws only)
    SMT_GQL_NEXT
  | -- | @ping@ (graphql-ws only)
    SMT_GQL_PING
  | -- | @pong@ (graphql-ws only)
    SMT_GQL_PONG
  | -- | @connection_ack@ (both protocols)
    SMT_GQL_CONNECTION_ACK
  | -- | @error@ (both protocols)
    SMT_GQL_ERROR
  | -- | @complete@ (both protocols)
    SMT_GQL_COMPLETE
  deriving (Eq)
-- | Renders each tag as the string used in the @"type"@ field of a server
-- message. The 'J.ToJSON' instance for this type is defined via 'show'.
instance Show ServerMsgType where
  -- specific to `Apollo` clients
  show SMT_GQL_CONNECTION_KEEP_ALIVE = "ka"
  show SMT_GQL_CONNECTION_ERROR = "connection_error"
  show SMT_GQL_DATA = "data"
  -- specific to `graphql-ws` clients
  show SMT_GQL_NEXT = "next"
  show SMT_GQL_PING = "ping"
  show SMT_GQL_PONG = "pong"
  -- common to clients of both protocols
  show SMT_GQL_CONNECTION_ACK = "connection_ack"
  show SMT_GQL_ERROR = "error"
  show SMT_GQL_COMPLETE = "complete"
-- | Encodes a tag as a JSON string holding its 'show' rendering.
instance J.ToJSON ServerMsgType where
  toJSON msgType = J.String (pack (show msgType))
-- | Parameters of a @connection_init@ message, decoded from its optional
-- @payload@ field.
data ConnParams = ConnParams
  { -- | Headers supplied by the client, if any. Strict, for consistency
    -- with the fields of the other message records in this module.
    _cpHeaders :: !(Maybe (HashMap Text Text))
  }
  deriving stock (Show, Eq)

$(J.deriveJSON hasuraJSON ''ConnParams)
-- | A request to start an operation.
--
-- This is the decoding target of Apollo's @start@ message and also of
-- graphql-ws's @subscribe@ message.
data StartMsg = StartMsg
  { -- | Client-chosen id of the operation.
    _smId :: !OperationId,
    -- | The (unparsed) GraphQL request to run.
    _smPayload :: !GQLReqUnparsed
  }
  deriving (Show, Eq)

$(J.deriveJSON hasuraJSON ''StartMsg)
-- | A request to stop an operation.
--
-- This is the decoding target of Apollo's @stop@ message and also of
-- graphql-ws's @complete@ message.
data StopMsg = StopMsg
  { -- | Id of the operation to stop. Strict, like the id fields of
    -- 'StartMsg' and 'SubscribeMsg'.
    _stId :: !OperationId
  }
  deriving (Show, Eq)

$(J.deriveJSON hasuraJSON ''StopMsg)
-- | Optional payload of graphql-ws @ping@ / @pong@ messages.
--
-- Specific to graphql-ws.
data PingPongPayload = PingPongPayload
  { -- | Free-form message.
    -- NOTE: this is not within the spec, but is specific to our usecase.
    _smMessage :: !(Maybe Text)
  }
  deriving stock (Show, Eq)

$(J.deriveJSON hasuraJSON ''PingPongPayload)
-- | A graphql-ws ping/pong payload whose message is @"keepalive"@.
--
-- Specific to graphql-ws.
keepAliveMessage :: PingPongPayload
keepAliveMessage = PingPongPayload {_smMessage = Just (pack "keepalive")}
-- | Shape of a graphql-ws @subscribe@ message.
--
-- NOTE: the 'J.FromJSON' instance of 'ClientMsg' in this module decodes
-- @subscribe@ into 'StartMsg', not into this type.
data SubscribeMsg = SubscribeMsg
  { -- | Client-chosen id of the subscription.
    _subId :: !OperationId,
    -- | The (unparsed) GraphQL request.
    _subPayload :: !GQLReqUnparsed
  }
  deriving (Show, Eq)

$(J.deriveJSON hasuraJSON ''SubscribeMsg)
-- | A message received from a client, for either protocol.
--
-- Several wire types share a constructor: @start@ and @subscribe@ both
-- decode to 'CMStart', and @stop@ and @complete@ both decode to 'CMStop'.
data ClientMsg
  = -- | @connection_init@, with optional connection parameters.
    CMConnInit !(Maybe ConnParams)
  | -- | @start@ or @subscribe@.
    CMStart !StartMsg
  | -- | @stop@ or @complete@.
    CMStop !StopMsg
  | -- | @connection_terminate@; specific to Apollo clients.
    CMConnTerm
  | -- | @ping@; specific to graphql-ws clients.
    CMPing !(Maybe PingPongPayload)
  | -- | @pong@; specific to graphql-ws clients.
    CMPong !(Maybe PingPongPayload)
  deriving (Show, Eq)
-- | Decodes a client message by dispatching on its @"type"@ field.
--
-- Message types from both protocols are accepted side by side. Any other
-- type makes the parse fail.
instance J.FromJSON ClientMsg where
  parseJSON = J.withObject "ClientMessage" decodeByType
    where
      decodeByType obj = do
        msgType <- obj J..: "type"
        case (msgType :: String) of
          "connection_init" -> CMConnInit <$> optionalPayload obj
          "start" -> CMStart <$> wholeObject obj
          "stop" -> CMStop <$> wholeObject obj
          "connection_terminate" -> pure CMConnTerm
          -- graphql-ws specific message types; @subscribe@ and @complete@
          -- reuse the same record shapes as @start@ and @stop@
          "subscribe" -> CMStart <$> wholeObject obj
          "complete" -> CMStop <$> wholeObject obj
          "ping" -> CMPing <$> optionalPayload obj
          "pong" -> CMPong <$> optionalPayload obj
          unknown -> fail $ "unexpected type for ClientMessage: " <> unknown

      -- Re-decode the entire message object as the target record type.
      wholeObject o = J.parseJSON (J.Object o)

      -- The optional @"payload"@ field (absent or null gives 'Nothing').
      optionalPayload o = o J..:? "payload"
-- | A result for an operation.
--
-- It is sent as @data@ to Apollo clients ('SMData') and as @next@ to
-- graphql-ws clients ('SMNext').
data DataMsg = DataMsg
  { -- | The operation this result belongs to.
    _dmId :: !OperationId,
    -- | The GraphQL response, encoded as the message payload.
    _dmPayload :: !GQResponse
  }
-- | An @error@ message about a single operation.
data ErrorMsg = ErrorMsg
  { -- | The operation that failed.
    _emId :: !OperationId,
    -- | Error details, encoded as-is into the message payload.
    _emPayload :: !J.Value
  }
  deriving (Show, Eq)
-- | Body of a @complete@ message: the id of the operation it concerns.
-- Carried by the 'SMComplete' server message.
newtype CompletionMsg = CompletionMsg {unCompletionMsg :: OperationId}
  deriving (Show, Eq)

-- | Decodes from an object holding the operation id under @"id"@.
instance J.FromJSON CompletionMsg where
  parseJSON = J.withObject "CompletionMsg" $ \t ->
    CompletionMsg <$> t J..: "id"

-- | Encodes as @{"id": <operation id>}@, the exact inverse of the
-- 'J.FromJSON' instance.
--
-- Encoding the id itself matters: rendering it with 'show' would emit the
-- Haskell record syntax (@OperationId {unOperationId = ...}@) as a string,
-- which no client can use and which cannot be decoded back.
instance J.ToJSON CompletionMsg where
  toJSON (CompletionMsg opId) = J.object ["id" J..= opId]
-- | Payload of a @connection_error@ message.
newtype ConnErrMsg = ConnErrMsg
  { unConnErrMsg :: Text
  }
  deriving (Show, Eq, J.ToJSON, J.FromJSON, IsString)
-- | The message the server reports when it closes a websocket because of an
-- error; see 'ServerErrorCode' and 'encodeServerErrorMsg'. Its JSON encoding
-- is derived with 'hasuraJSON'.
data ServerErrorMsg = ServerErrorMsg
  { unServerErrorMsg :: Text
  }
  deriving stock (Show, Eq)

$(J.deriveJSON hasuraJSON ''ServerErrorMsg)
-- | A message sent from the server to a client.
--
-- 'serverMsgType' gives its @type@ tag and 'encodeServerMsg' serialises it.
data ServerMsg
  = -- | @connection_ack@
    SMConnAck
  | -- | @ka@ (Apollo keep-alive)
    SMConnKeepAlive
  | -- | @connection_error@
    SMConnErr !ConnErrMsg
  | -- | @data@ (Apollo)
    SMData !DataMsg
  | -- | @error@
    SMErr !ErrorMsg
  | -- | @complete@
    SMComplete !CompletionMsg
  | -- | @next@ (graphql-ws specific)
    SMNext !DataMsg
  | -- | @ping@ (graphql-ws specific)
    SMPing !(Maybe PingPongPayload)
  | -- | @pong@ (graphql-ws specific)
    SMPong !(Maybe PingPongPayload)
-- | This is sent from the server to the client while closing the websocket
-- on encountering an error.
--
-- Each constructor is named after its numeric close code, and
-- 'encodeServerErrorMsg' starts the rendered message with that code.
data ServerErrorCode
  = -- | 1002: protocol error
    ProtocolError1002
  | -- | 4400: generic error with a free-form description
    GenericError4400 !String
  | -- | 4401: unauthorized
    Unauthorized4401
  | -- | 4403: forbidden
    Forbidden4403
  | -- | 4408: connection initialisation timed out
    ConnectionInitTimeout4408
  | -- | 4409: a subscriber with this operation id already exists
    NonUniqueSubscription4409 !OperationId
  | -- | 4429: too many requests
    TooManyRequests4429
  deriving stock (Show)
-- | Render a 'ServerErrorCode' as the close message sent to the client.
--
-- The text always starts with the numeric code (e.g. @"4401: Unauthorized"@).
-- It is wrapped in 'ServerErrorMsg', so the returned bytes are the JSON
-- encoding of that record rather than the bare text.
encodeServerErrorMsg :: ServerErrorCode -> BL.ByteString
encodeServerErrorMsg ecode = encJToLBS . encJFromJValue . ServerErrorMsg $ case ecode of
  ProtocolError1002 -> "1002: Protocol Error"
  GenericError4400 msg -> "4400: " <> pack msg
  Unauthorized4401 -> "4401: Unauthorized"
  Forbidden4403 -> "4403: Forbidden"
  ConnectionInitTimeout4408 -> "4408: Connection initialisation timeout"
  -- Interpolate the raw operation id. Using 'show' here would embed the
  -- Haskell record rendering (OperationId {unOperationId = "..."}) instead
  -- of the id itself.
  NonUniqueSubscription4409 opId ->
    "4409: Subscriber for " <> unOperationId opId <> " already exists"
  TooManyRequests4429 -> "4429: Too many requests"
-- | The @type@ tag of a server message.
serverMsgType :: ServerMsg -> ServerMsgType
serverMsgType = \case
  SMConnAck -> SMT_GQL_CONNECTION_ACK
  SMConnKeepAlive -> SMT_GQL_CONNECTION_KEEP_ALIVE
  SMConnErr {} -> SMT_GQL_CONNECTION_ERROR
  SMData {} -> SMT_GQL_DATA
  SMErr {} -> SMT_GQL_ERROR
  SMComplete {} -> SMT_GQL_COMPLETE
  SMPing {} -> SMT_GQL_PING
  SMPong {} -> SMT_GQL_PONG
  SMNext {} -> SMT_GQL_NEXT
-- | Serialise a 'ServerMsg' into the JSON object sent over the websocket.
--
-- Every object starts with a @"type"@ field holding the 'serverMsgType' tag.
-- It is followed by the fields the constructor carries: @"id"@ for
-- per-operation messages, and @"payload"@ where there is one. Ping/pong
-- messages omit @"payload"@ entirely when it is 'Nothing'.
encodeServerMsg :: ServerMsg -> BL.ByteString
encodeServerMsg msg =
  encJToLBS . encJFromAssocList $
    ("type", encJFromJValue (serverMsgType msg)) : bodyFields
  where
    bodyFields = case msg of
      SMConnAck -> []
      SMConnKeepAlive -> []
      SMConnErr connErr -> [("payload", encJFromJValue connErr)]
      SMData dataMsg -> resultFields dataMsg
      SMErr (ErrorMsg opId errPayload) ->
        [("id", encJFromJValue opId), ("payload", encJFromJValue errPayload)]
      SMComplete (CompletionMsg opId) -> [("id", encJFromJValue opId)]
      SMNext dataMsg -> resultFields dataMsg
      SMPing mPayload -> optionalPayload mPayload
      SMPong mPayload -> optionalPayload mPayload

    -- @data@ (Apollo) and @next@ (graphql-ws) carry an identical body.
    resultFields (DataMsg opId response) =
      [("id", encJFromJValue opId), ("payload", encodeGQResp response)]

    optionalPayload mPayload =
      maybe [] (\p -> [("payload", encJFromJValue p)]) mPayload
-- | Status of the timer that is needed while initialising a connection with
-- the server; this is specific to the graphql-ws protocol.
--
-- A timer starts out 'Running'. It becomes 'Done' either when it fires
-- ('getNewWSTimer') or when it is stopped ('stopWSTimer').
data WSConnInitTimerStatus
  = Running
  | Done
  deriving stock (Show, Eq)

-- | A connection-initialisation timer. It holds the current status and a
-- 'TMVar' that is filled with @()@ when, and only when, the timer fires.
type WSConnInitTimer = (TVar WSConnInitTimerStatus, TMVar ())
-- | Block until the timer fires.
--
-- 'readTMVar' leaves the 'TMVar' full, so any number of waiters are released.
-- If the timer is stopped before it fires, the 'TMVar' is never filled and
-- this keeps blocking. NOTE(review): callers presumably race this against
-- other events.
waitForWSTimer :: WSConnInitTimer -> IO ()
waitForWSTimer (_, fired) = atomically (readTMVar fired)
-- | Stop the timer by marking it 'Done'. If it has not fired yet, it never
-- will.
stopWSTimer :: WSConnInitTimer -> IO ()
stopWSTimer (status, _) = atomically (writeTVar status Done)
-- | Read the timer's current status.
getWSTimerState :: WSConnInitTimer -> IO WSConnInitTimerStatus
getWSTimerState (status, _) = atomically (readTVar status)
-- | Start a new connection-initialisation timer that fires after @timeout@.
--
-- A forked thread sleeps for the full timeout and then, in one STM
-- transaction, checks the status:
--
-- * if the timer is still 'Running', it marks it 'Done' and fills the
--   'TMVar', which releases 'waitForWSTimer';
-- * if the timer was already stopped, it does nothing.
--
-- 'stopWSTimer' does not cancel the thread; the thread just finds the status
-- 'Done' when it wakes up.
getNewWSTimer :: Seconds -> IO WSConnInitTimer
getNewWSTimer timeout = do
  status <- newTVarIO Running
  fired <- newEmptyTMVarIO
  -- The check and the update happen atomically, so the 'TMVar' is filled
  -- only if the timer was still running at expiry.
  let expire = atomically $ do
        current <- readTVar status
        if current == Running
          then writeTVar status Done >> putTMVar fired ()
          else pure ()
  _ <- forkIO (sleep (seconds timeout) >> expire)
  pure (status, fired)