Wasp ai retry (#1288)

* Wasp AI now retries requests to chat GPT.

* Improved retry.

* Got tests for retry working.

* one more test.

* Improved retry tests a bit.

* Finished tests.

* fix
This commit is contained in:
Martin Šošić 2023-06-26 17:48:08 +02:00 committed by GitHub
parent 001068ca1a
commit 466e33d50c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 152 additions and 5 deletions

View File

@ -1,4 +1,5 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE ViewPatterns #-}
module Wasp.AI.OpenAI.ChatGPT
( queryChatGPT,
@ -18,8 +19,11 @@ import qualified Data.Aeson as Aeson
import Data.ByteString.UTF8 as BSU
import Data.Text (Text)
import GHC.Generics (Generic)
import qualified Network.HTTP.Conduit as HTTP.C
import qualified Network.HTTP.Simple as HTTP
import UnliftIO.Exception (catch, throwIO)
import Wasp.AI.OpenAI (OpenAIApiKey)
import qualified Wasp.Util.IO.Retry as R
queryChatGPT :: OpenAIApiKey -> ChatGPTParams -> [ChatMessage] -> IO Text
queryChatGPT apiKey params requestMessages = do
@ -30,12 +34,14 @@ queryChatGPT apiKey params requestMessages = do
]
<> ["temperature" .= t | Just t <- pure $ _temperature params]
request =
HTTP.setRequestHeader "Authorization" [BSU.fromString $ "Bearer " <> apiKey] $
HTTP.setRequestBodyJSON reqBodyJson $
HTTP.parseRequest_ "POST https://api.openai.com/v1/chat/completions"
-- 90 seconds should be more than enough for ChatGPT to generate an answer, or reach its own timeout.
-- If it proves in the future that it might need more time, we can increase this number.
HTTP.setRequestResponseTimeout (HTTP.C.responseTimeoutMicro $ secondsToMicroSeconds 90) $
HTTP.setRequestHeader "Authorization" [BSU.fromString $ "Bearer " <> apiKey] $
HTTP.setRequestBodyJSON reqBodyJson $
HTTP.parseRequest_ "POST https://api.openai.com/v1/chat/completions"
-- TODO: Consider using httpJSONEither here, so I can handle errors better.
response <- HTTP.httpJSON request
response <- httpJSONWithRetry request
-- TODO: I should probably check status code here, confirm it is 200.
let _responseStatusCode = HTTP.getResponseStatusCode response
@ -43,6 +49,21 @@ queryChatGPT apiKey params requestMessages = do
let (chatResponse :: ChatResponse) = HTTP.getResponseBody response
return $ content $ message $ head $ choices chatResponse
where
secondsToMicroSeconds :: Int -> Int
secondsToMicroSeconds = (* 1000000)
httpJSONWithRetry request =
-- NOTE: There is no strong reason for using linear pause here, or exactly 2 retries, we went
-- with these settings as reasonable defaults.
R.retry
(R.linearPause $ fromIntegral $ secondsToMicroSeconds 10)
2
( (pure <$> HTTP.httpJSON request)
`catch` (\e@(HTTP.HttpExceptionRequest _req HTTP.C.ResponseTimeout) -> pure $ Left e)
`catch` (\e@(HTTP.HttpExceptionRequest _req HTTP.C.ConnectionTimeout) -> pure $ Left e)
)
>>= either throwIO pure
data ChatGPTParams = ChatGPTParams
{ _model :: !Model,

View File

@ -0,0 +1,60 @@
{-# LANGUAGE ScopedTypeVariables #-}
module Wasp.Util.IO.Retry
( retry,
constPause,
linearPause,
expPause,
customPause,
MonadRetry (..),
PauseStrategy,
)
where
import Control.Concurrent (threadDelay)
import Numeric.Natural (Natural)
import Prelude hiding (readFile, writeFile)
-- | Runs given action, and then if it fails, retries it, up to maxNumRetries.
-- Uses provided pauseStrategy to calculate pause between tries.
retry :: forall m e a. (MonadRetry m) => PauseStrategy -> Natural -> m (Either e a) -> m (Either e a)
retry (PauseStrategy calcPause) maxNumRetries action = go 0
where
maxNumTries :: Natural
maxNumTries = maxNumRetries + 1
go :: Natural -> m (Either e a)
go numFailedTries =
action >>= \case
Right result -> pure $ Right result
Left e ->
let numFailedTries' = numFailedTries + 1
in if numFailedTries' < maxNumTries
then do
rThreadDelay $ fromIntegral $ calcPause numFailedTries'
go numFailedTries'
else pure $ Left e
class (Monad m) => MonadRetry m where
rThreadDelay :: Int -> m ()
instance MonadRetry IO where
rThreadDelay = threadDelay
newtype PauseStrategy = PauseStrategy (NumFailedTries -> Microseconds)
type Microseconds = Natural
type NumFailedTries = Natural
constPause :: Microseconds -> PauseStrategy
constPause basePause = PauseStrategy (const basePause)
linearPause :: Microseconds -> PauseStrategy
linearPause basePause = PauseStrategy (* basePause)
expPause :: Microseconds -> PauseStrategy
expPause basePause = PauseStrategy $ \i -> basePause * 2 ^ (i - 1)
customPause :: (NumFailedTries -> Microseconds) -> PauseStrategy
customPause = PauseStrategy

View File

@ -0,0 +1,64 @@
{-# LANGUAGE FlexibleInstances #-}
module Util.IO.RetryTest where
import Control.Monad (forM_)
import Control.Monad.State (MonadState (get), State, modify, runState)
import Numeric.Natural (Natural)
import Test.Tasty.Hspec (Spec, describe, it, shouldBe)
import qualified Wasp.Util.IO.Retry as R
spec_RetryTest :: Spec
spec_RetryTest = do
describe "retry" $ do
describe "when action succeeds on the first try" $ do
it "runs action only once" $ do
runMockRetry (R.constPause 42) 2 (mockAction (NumFails 0))
`shouldBe` (Right (), [ActionCall])
describe "when action fails 2 times and then succeeds" $ do
let action = mockAction (NumFails 2)
describe "and maxNumRetries >= 2" $ do
it "will run it 3 times and end with success" $ do
forM_ [2, 3, 4, 10] $ \maxNumRetries ->
runMockRetry (R.constPause 42) maxNumRetries action
`shouldBe` (Right (), [ActionCall, ThreadDelayCall 42, ActionCall, ThreadDelayCall 42, ActionCall])
describe "and maxNumRetries < 2" $ do
it "will run it (maxNumRetries + 1) times and end with failure" $ do
runMockRetry (R.constPause 42) 0 action
`shouldBe` (Left (), [ActionCall])
runMockRetry (R.constPause 42) 1 action
`shouldBe` (Left (), [ActionCall, ThreadDelayCall 42, ActionCall])
describe "determines pauses according to provided pause strategy" $ do
let action = mockAction (NumFails 3)
let testPause = \pauseStrategy _expectedPauses@(p1, p2, p3) ->
snd (runMockRetry pauseStrategy 5 action)
`shouldBe` [ActionCall, ThreadDelayCall p1, ActionCall, ThreadDelayCall p2, ActionCall, ThreadDelayCall p3, ActionCall]
it "for constPause" $ testPause (R.constPause 10) (10, 10, 10)
it "for linearPause" $ testPause (R.linearPause 10) (10, 20, 30)
it "for expPause" $ testPause (R.expPause 10) (10, 20, 40)
it "for customPause" $ testPause (R.customPause (^ (2 :: Int))) (1, 4, 9)
runMockRetry :: R.PauseStrategy -> Natural -> MockAction -> (Either () (), [Event])
runMockRetry pause maxNumRetries action = runState (R.retry pause maxNumRetries action) []
type MockAction = MockRetryMonad (Either () ())
mockAction :: NumFails -> MockAction
mockAction (NumFails numFails) = do
events <- get
let result =
if length (filter (== ActionCall) events) >= numFails
then Right ()
else Left ()
modify (++ [ActionCall])
return result
newtype NumFails = NumFails Int
data Event = ThreadDelayCall Int | ActionCall
deriving (Show, Eq)
type MockRetryMonad = State [Event]
instance R.MonadRetry MockRetryMonad where
rThreadDelay microseconds = modify (++ [ThreadDelayCall microseconds])

View File

@ -344,6 +344,7 @@ library
Wasp.Util.Control.Monad
Wasp.Util.Fib
Wasp.Util.IO
Wasp.Util.IO.Retry
Wasp.Util.Terminal
Wasp.Util.FilePath
Wasp.WaspignoreFile
@ -555,6 +556,7 @@ test-suite waspc-test
UtilTest
Util.Diff
Util.FilePathTest
Util.IO.RetryTest
SemanticVersionTest
WaspignoreFileTest
Paths_waspc