Implemented graceful shutdown for websockets (#2827)

This commit is contained in:
José Lorenzo Rodríguez 2019-09-09 22:26:04 +02:00 committed by Alexis King
parent eea1e33d19
commit 5609fba393
4 changed files with 148 additions and 53 deletions

View File

@ -30,8 +30,9 @@ import Hasura.Prelude
import Hasura.RQL.DDL.Metadata (fetchMetadata)
import Hasura.RQL.Types (SQLGenCtx (..), SchemaCache (..),
adminUserInfo, emptySchemaCache)
import Hasura.Server.App (SchemaCacheRef (..), getSCFromRef,
logInconsObjs, mkWaiApp)
import Hasura.Server.App (HasuraApp(..), SchemaCacheRef (..),
getSCFromRef, logInconsObjs,
mkWaiApp)
import Hasura.Server.Auth
import Hasura.Server.CheckUpdates (checkForUpdates)
import Hasura.Server.Init
@ -150,7 +151,7 @@ main = do
-- safe init catalog
dbId <- initialise pool sqlGenCtx logger httpManager
(app, cacheRef, cacheInitTime) <-
HasuraApp app cacheRef cacheInitTime shutdownApp <-
mkWaiApp isoL loggerCtx sqlGenCtx enableAL pool ci httpManager am
corsCfg enableConsole consoleAssetsDir enableTelemetry
instanceId enabledAPIs lqOpts
@ -166,7 +167,7 @@ main = do
let warpSettings = Warp.setPort port
. Warp.setHost host
. Warp.setGracefulShutdownTimeout (Just 30) -- 30s graceful shutdown
. Warp.setInstallShutdownHandler (shutdownHandler logger)
. Warp.setInstallShutdownHandler (shutdownHandler logger shutdownApp)
$ Warp.defaultSettings
maxEvThrds <- getFromEnv defaultMaxEventThreads "HASURA_GRAPHQL_EVENTS_HTTP_POOL_SIZE"
@ -283,10 +284,18 @@ main = do
-- requests is already implemented in Warp, and is triggered by invoking the 'closeSocket' callback.
-- We only catch the SIGTERM signal once; that is, if the signal is delivered a second time (for
-- example because the user asks the process to stop once again), we terminate the process immediately.
shutdownHandler :: Logger -> IO () -> IO ()
shutdownHandler (Logger logger) closeSocket =
void $ Signals.installHandler Signals.sigTERM (Signals.CatchOnce $ closeSocket >> logShutdown) Nothing
-- The second argument is the application-level shutdown action; the third is
-- the 'closeSocket' callback that Warp hands over when it installs this
-- handler through 'Warp.setInstallShutdownHandler'.
shutdownHandler :: Logger -> IO () -> IO () -> IO ()
shutdownHandler (Logger logger) shutdownApp closeSocket =
  void $ Signals.installHandler Signals.sigTERM onSigTerm Nothing
  where
    -- Registered with 'Signals.CatchOnce': invoke the closeSocket callback
    -- first, then the application shutdown action, and finally log.
    onSigTerm = Signals.CatchOnce (closeSocket >> shutdownApp >> logShutdown)

    logShutdown =
      logger $ mkGenericStrLog LevelInfo "server" "gracefully shutting down server"

View File

@ -3,6 +3,7 @@
module Hasura.GraphQL.Transport.WebSocket
( createWSServerApp
, createWSServerEnv
, stopWSServerApp
, WSServerEnv
) where
@ -532,3 +533,6 @@ createWSServerApp authMode serverEnv =
(onConn (_wseLogger serverEnv) (_wseCorsPolicy serverEnv))
(onMessage authMode serverEnv)
(onClose (_wseLogger serverEnv) $ _wseLiveQMap serverEnv)
-- | Stops the websocket server stored in the given environment by handing its
-- '_wseServer' field to 'WS.shutdown'.
stopWSServerApp :: WSServerEnv -> IO ()
stopWSServerApp = WS.shutdown . _wseServer

View File

@ -19,6 +19,7 @@ module Hasura.GraphQL.Transport.WebSocket.Server
, createWSServer
, closeAll
, createServerApp
, shutdown
) where
import qualified Control.Concurrent.Async as A
@ -33,6 +34,7 @@ import qualified Data.UUID.V4 as UUID
import qualified ListT
import qualified Network.WebSockets as WS
import qualified StmContainers.Map as STMMap
import Data.Word (Word16)
import Control.Exception (try)
import qualified Hasura.Logging as L
@ -90,9 +92,17 @@ getWSId :: WSConn a -> WSId
getWSId = _wcConnId
closeConn :: WSConn a -> BL.ByteString -> IO ()
closeConn wsConn bs = do
closeConn wsConn bs = closeConnWithCode wsConn 1000 bs -- 1000 is "normal close"
-- | Closes a connection with close code 1012 ("Server is restarting").
-- Well-behaved clients are expected to respond by reconnecting with a
-- backoff of a few seconds.
forceConnReconnect :: WSConn a -> BL.ByteString -> IO ()
forceConnReconnect wsConn = closeConnWithCode wsConn serverRestarting
  where
    -- websocket close code meaning "Server is restarting"
    serverRestarting = 1012
closeConnWithCode :: WSConn a -> Word16 -> BL.ByteString -> IO ()
closeConnWithCode wsConn code bs = do
(L.unLogger . _wcLogger) wsConn $ WSLog (_wcConnId wsConn) $ ECloseSent $ TBS.fromLBS bs
WS.sendClose (_wcConnRaw wsConn) bs
WS.sendCloseCode (_wcConnRaw wsConn) code bs
-- writes to a queue instead of the raw connection
-- so that sendMsg doesn't block
@ -102,23 +112,48 @@ sendMsg wsConn msg =
type ConnMap a = STMMap.Map WSId (WSConn a)
-- | Lifecycle state of a websocket server. 'WSServer' keeps a value of this
-- type in a 'STM.TVar' so it can be inspected and switched atomically.
data ServerStatus a
= AcceptingConns !(ConnMap a)
-- ^ The server is running; carries the map of registered connections,
-- keyed by 'WSId'.
| ShuttingDown
-- ^ Written by 'shutdown'. In this state 'createServerApp' rejects pending
-- connections instead of registering them.
data WSServer a
= WSServer
{ _wssLogger :: L.Logger
, _wssConnMap :: ConnMap a
{ _wssLogger :: L.Logger
, _wssStatus :: !(STM.TVar (ServerStatus a))
}
createWSServer :: L.Logger -> STM.STM (WSServer a)
createWSServer logger = WSServer logger <$> STMMap.new
createWSServer logger = do
connMap <- STMMap.new
serverStatus <- STM.newTVar (AcceptingConns connMap)
return $ WSServer logger serverStatus
closeAll :: WSServer a -> BL.ByteString -> IO ()
closeAll (WSServer (L.Logger writeLog) connMap) msg = do
closeAll (WSServer (L.Logger writeLog) serverStatus) msg = do
writeLog $ L.debugT "closing all connections"
conns <- STM.atomically $ do
conns <- ListT.toList $ STMMap.listT connMap
STMMap.reset connMap
return conns
void $ A.mapConcurrently (flip closeConn msg . snd) conns
conns <- STM.atomically $ flushConnMap serverStatus
closeAllWith (flip closeConn) msg conns
-- | Applies the supplied closing action to every connection in the list,
-- running the actions concurrently and passing each the same message. The
-- 'WSId' half of each pair is ignored and the individual results are
-- discarded.
closeAllWith
  :: (BL.ByteString -> WSConn a -> IO ())
  -> BL.ByteString
  -> [(WSId, WSConn a)]
  -> IO ()
closeAllWith closer msg conns =
  void $ A.forConcurrently conns $ \entry -> closer msg (snd entry)
-- | Empties the connection map of a server that is still accepting
-- connections and returns the entries it held just before the reset. If the
-- server is already shutting down there is no map to flush, so the result is
-- the empty list.
flushConnMap :: STM.TVar (ServerStatus a) -> STM.STM [(WSId, WSConn a)]
flushConnMap serverStatus = STM.readTVar serverStatus >>= drain
  where
    drain ShuttingDown = return []
    drain (AcceptingConns connMap) =
      -- take the snapshot first, then clear the map; both steps run inside
      -- the caller's STM transaction
      ListT.toList (STMMap.listT connMap) <* STMMap.reset connMap
data AcceptWith a
= AcceptWith
@ -145,15 +180,28 @@ createServerApp
-- user provided handlers
-> WSHandlers a
-- aka WS.ServerApp
-> WS.PendingConnection -> IO ()
createServerApp (WSServer logger@(L.Logger writeLog) connMap) wsHandlers pendingConn = do
-> WS.PendingConnection
-> IO ()
createServerApp (WSServer logger@(L.Logger writeLog) serverStatus) wsHandlers pendingConn = do
wsId <- WSId <$> UUID.nextRandom
writeLog $ WSLog wsId EConnectionRequest
let reqHead = WS.pendingRequest pendingConn
onConnRes <- _hOnConn wsHandlers wsId reqHead
either (onReject wsId) (onAccept wsId) onConnRes
status <- STM.readTVarIO serverStatus
case status of
AcceptingConns _ -> do
let reqHead = WS.pendingRequest pendingConn
onConnRes <- _hOnConn wsHandlers wsId reqHead
either (onReject wsId) (onAccept wsId) onConnRes
ShuttingDown ->
onReject wsId shuttingDownReject
where
shuttingDownReject =
WS.RejectRequest 503
"Service Unavailable"
[("Retry-After", "0")]
"Server is shutting down"
onReject wsId rejectRequest = do
WS.rejectRequestWith pendingConn rejectRequest
writeLog $ WSLog wsId ERejected
@ -161,39 +209,62 @@ createServerApp (WSServer logger@(L.Logger writeLog) connMap) wsHandlers pending
onAccept wsId (AcceptWith a acceptWithParams keepAliveM onJwtExpiryM) = do
conn <- WS.acceptRequestWith pendingConn acceptWithParams
writeLog $ WSLog wsId EAccepted
sendQ <- STM.newTQueueIO
let wsConn = WSConn wsId logger conn sendQ a
STM.atomically $ STMMap.insert wsConn wsId connMap
rcvRef <- A.async $ forever $ do
msg <- WS.receiveData conn
writeLog $ WSLog wsId $ EMessageReceived $ TBS.fromLBS msg
_hOnMessage wsHandlers wsConn msg
status <- STM.atomically $ do
status <- STM.readTVar serverStatus
case status of
ShuttingDown -> pure ()
AcceptingConns connMap -> STMMap.insert wsConn wsId connMap
return status
sendRef <- A.async $ forever $ do
msg <- STM.atomically $ STM.readTQueue sendQ
WS.sendTextData conn msg
writeLog $ WSLog wsId $ EMessageSent $ TBS.fromLBS msg
case status of
ShuttingDown -> do
-- Bad luck, we were in the process of shutting the server down but a new
-- connection was accepted. Let's just close it politely
forceConnReconnect wsConn "shutting server down"
_hOnClose wsHandlers wsConn
keepAliveRefM <- forM keepAliveM $ \action -> A.async $ action wsConn
onJwtExpiryRefM <- forM onJwtExpiryM $ \action -> A.async $ action wsConn
AcceptingConns connMap -> do
rcvRef <- A.async $ forever $ do
msg <- WS.receiveData conn
writeLog $ WSLog wsId $ EMessageReceived $ TBS.fromLBS msg
_hOnMessage wsHandlers wsConn msg
-- terminates on WS.ConnectionException and JWT expiry
let waitOnRefs = catMaybes [keepAliveRefM, onJwtExpiryRefM]
<> [rcvRef, sendRef]
res <- try $ A.waitAnyCancel waitOnRefs
sendRef <- A.async $ forever $ do
msg <- STM.atomically $ STM.readTQueue sendQ
WS.sendTextData conn msg
writeLog $ WSLog wsId $ EMessageSent $ TBS.fromLBS msg
case res of
Left ( _ :: WS.ConnectionException) -> do
writeLog $ WSLog (_wcConnId wsConn) ECloseReceived
onConnClose wsConn
-- this will happen when jwt is expired
Right _ -> do
writeLog $ WSLog (_wcConnId wsConn) EJwtExpired
onConnClose wsConn
keepAliveRefM <- forM keepAliveM $ \action -> A.async $ action wsConn
onJwtExpiryRefM <- forM onJwtExpiryM $ \action -> A.async $ action wsConn
onConnClose wsConn = do
-- terminates on WS.ConnectionException and JWT expiry
let waitOnRefs = catMaybes [keepAliveRefM, onJwtExpiryRefM]
<> [rcvRef, sendRef]
res <- try $ A.waitAnyCancel waitOnRefs
case res of
Left ( _ :: WS.ConnectionException) -> do
writeLog $ WSLog (_wcConnId wsConn) ECloseReceived
onConnClose connMap wsConn
-- this will happen when jwt is expired
Right _ -> do
writeLog $ WSLog (_wcConnId wsConn) EJwtExpired
onConnClose connMap wsConn
onConnClose connMap wsConn = do
STM.atomically $ STMMap.delete (_wcConnId wsConn) connMap
_hOnClose wsHandlers wsConn
writeLog $ WSLog (_wcConnId wsConn) EClosed
-- | Puts the websocket server into the 'ShuttingDown' state and asks every
-- connection that was registered at that moment to reconnect (see
-- 'forceConnReconnect').
shutdown :: WSServer a -> IO ()
shutdown (WSServer (L.Logger writeLog) serverStatus) = do
  writeLog $ L.debugT "Shutting websockets server down"
  -- Flushing the map and flipping the status happen in one transaction, so
  -- the returned list is exactly what the map held when the status changed.
  drained <- STM.atomically $
    flushConnMap serverStatus <* STM.writeTVar serverStatus ShuttingDown
  closeAllWith (\msg wsConn -> forceConnReconnect wsConn msg) "shutting server down" drained

View File

@ -435,6 +435,14 @@ initErrExit e = do
<> T.unpack (qeError e)
exitFailure
-- | Everything 'mkWaiApp' returns to its caller, bundled as a record.
data HasuraApp
= HasuraApp
{ _hapApplication :: !Wai.Application
-- ^ The WAI application; 'mkWaiApp' builds it with 'WS.websocketsOr' from
-- the websocket server app and the Spock app.
, _hapSchemaRef :: !SchemaCacheRef
-- ^ The schema cache reference created by 'mkWaiApp'.
, _hapCacheBuildTime :: !(Maybe UTCTime)
-- ^ Filled from @cacheBuiltTime@ in 'mkWaiApp'; presumably when the schema
-- cache was built, if known (TODO confirm).
, _hapShutdown :: !(IO ())
-- ^ Action that stops the websocket server ('WS.stopWSServerApp'); @main@
-- passes it to the SIGTERM shutdown handler.
}
mkWaiApp
:: Q.TxIsolation
-> L.LoggerCtx
@ -451,7 +459,7 @@ mkWaiApp
-> InstanceId
-> S.HashSet API
-> EL.LQOpts
-> IO (Wai.Application, SchemaCacheRef, Maybe UTCTime)
-> IO HasuraApp
mkWaiApp isoLevel loggerCtx sqlGenCtx enableAL pool ci httpManager mode corsCfg
enableConsole consoleAssetsDir enableTelemetry instanceId apis lqOpts = do
@ -494,10 +502,13 @@ mkWaiApp isoLevel loggerCtx sqlGenCtx enableAL pool ci httpManager mode corsCfg
consoleAssetsDir enableTelemetry
let wsServerApp = WS.createWSServerApp mode wsServerEnv
return ( WS.websocketsOr WS.defaultConnectionOptions wsServerApp spockApp
, schemaCacheRef
, cacheBuiltTime
)
stopWSServer = WS.stopWSServerApp wsServerEnv
return $ HasuraApp
(WS.websocketsOr WS.defaultConnectionOptions wsServerApp spockApp)
schemaCacheRef
cacheBuiltTime
stopWSServer
where
getTimeMs :: IO Int64
getTimeMs = (round . (* 1000)) `fmap` getPOSIXTime