mirror of
https://github.com/wasp-lang/wasp.git
synced 2024-10-26 17:10:02 +03:00
Wasp ai retry (#1288)
* Wasp AI now retries requests to chat GPT. * Improved retry. * Got tests for retry working. * one more test. * Improved retry tests a bit. * Finished tests. * fix
This commit is contained in:
parent
001068ca1a
commit
466e33d50c
@ -1,4 +1,5 @@
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE ViewPatterns #-}
|
||||
|
||||
module Wasp.AI.OpenAI.ChatGPT
|
||||
( queryChatGPT,
|
||||
@ -18,8 +19,11 @@ import qualified Data.Aeson as Aeson
|
||||
import Data.ByteString.UTF8 as BSU
|
||||
import Data.Text (Text)
|
||||
import GHC.Generics (Generic)
|
||||
import qualified Network.HTTP.Conduit as HTTP.C
|
||||
import qualified Network.HTTP.Simple as HTTP
|
||||
import UnliftIO.Exception (catch, throwIO)
|
||||
import Wasp.AI.OpenAI (OpenAIApiKey)
|
||||
import qualified Wasp.Util.IO.Retry as R
|
||||
|
||||
queryChatGPT :: OpenAIApiKey -> ChatGPTParams -> [ChatMessage] -> IO Text
|
||||
queryChatGPT apiKey params requestMessages = do
|
||||
@ -30,12 +34,14 @@ queryChatGPT apiKey params requestMessages = do
|
||||
]
|
||||
<> ["temperature" .= t | Just t <- pure $ _temperature params]
|
||||
request =
|
||||
HTTP.setRequestHeader "Authorization" [BSU.fromString $ "Bearer " <> apiKey] $
|
||||
HTTP.setRequestBodyJSON reqBodyJson $
|
||||
HTTP.parseRequest_ "POST https://api.openai.com/v1/chat/completions"
|
||||
-- 90 seconds should be more than enough for ChatGPT to generate an answer, or reach its own timeout.
|
||||
-- If it proves in the future that it might need more time, we can increase this number.
|
||||
HTTP.setRequestResponseTimeout (HTTP.C.responseTimeoutMicro $ secondsToMicroSeconds 90) $
|
||||
HTTP.setRequestHeader "Authorization" [BSU.fromString $ "Bearer " <> apiKey] $
|
||||
HTTP.setRequestBodyJSON reqBodyJson $
|
||||
HTTP.parseRequest_ "POST https://api.openai.com/v1/chat/completions"
|
||||
|
||||
-- TODO: Consider using httpJSONEither here, so I can handle errors better.
|
||||
response <- HTTP.httpJSON request
|
||||
response <- httpJSONWithRetry request
|
||||
|
||||
-- TODO: I should probably check status code here, confirm it is 200.
|
||||
let _responseStatusCode = HTTP.getResponseStatusCode response
|
||||
@ -43,6 +49,21 @@ queryChatGPT apiKey params requestMessages = do
|
||||
let (chatResponse :: ChatResponse) = HTTP.getResponseBody response
|
||||
|
||||
return $ content $ message $ head $ choices chatResponse
|
||||
where
|
||||
secondsToMicroSeconds :: Int -> Int
|
||||
secondsToMicroSeconds = (* 1000000)
|
||||
|
||||
httpJSONWithRetry request =
|
||||
-- NOTE: There is no strong reason for using linear pause here, or exactly 2 retries, we went
|
||||
-- with these settings as reasonable defaults.
|
||||
R.retry
|
||||
(R.linearPause $ fromIntegral $ secondsToMicroSeconds 10)
|
||||
2
|
||||
( (pure <$> HTTP.httpJSON request)
|
||||
`catch` (\e@(HTTP.HttpExceptionRequest _req HTTP.C.ResponseTimeout) -> pure $ Left e)
|
||||
`catch` (\e@(HTTP.HttpExceptionRequest _req HTTP.C.ConnectionTimeout) -> pure $ Left e)
|
||||
)
|
||||
>>= either throwIO pure
|
||||
|
||||
data ChatGPTParams = ChatGPTParams
|
||||
{ _model :: !Model,
|
||||
|
60
waspc/src/Wasp/Util/IO/Retry.hs
Normal file
60
waspc/src/Wasp/Util/IO/Retry.hs
Normal file
@ -0,0 +1,60 @@
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Wasp.Util.IO.Retry
|
||||
( retry,
|
||||
constPause,
|
||||
linearPause,
|
||||
expPause,
|
||||
customPause,
|
||||
MonadRetry (..),
|
||||
PauseStrategy,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent (threadDelay)
|
||||
import Numeric.Natural (Natural)
|
||||
import Prelude hiding (readFile, writeFile)
|
||||
|
||||
-- | Runs given action, and then if it fails, retries it, up to maxNumRetries.
|
||||
-- Uses provided pauseStrategy to calculate pause between tries.
|
||||
retry :: forall m e a. (MonadRetry m) => PauseStrategy -> Natural -> m (Either e a) -> m (Either e a)
|
||||
retry (PauseStrategy calcPause) maxNumRetries action = go 0
|
||||
where
|
||||
maxNumTries :: Natural
|
||||
maxNumTries = maxNumRetries + 1
|
||||
|
||||
go :: Natural -> m (Either e a)
|
||||
go numFailedTries =
|
||||
action >>= \case
|
||||
Right result -> pure $ Right result
|
||||
Left e ->
|
||||
let numFailedTries' = numFailedTries + 1
|
||||
in if numFailedTries' < maxNumTries
|
||||
then do
|
||||
rThreadDelay $ fromIntegral $ calcPause numFailedTries'
|
||||
go numFailedTries'
|
||||
else pure $ Left e
|
||||
|
||||
class (Monad m) => MonadRetry m where
|
||||
rThreadDelay :: Int -> m ()
|
||||
|
||||
instance MonadRetry IO where
|
||||
rThreadDelay = threadDelay
|
||||
|
||||
newtype PauseStrategy = PauseStrategy (NumFailedTries -> Microseconds)
|
||||
|
||||
type Microseconds = Natural
|
||||
|
||||
type NumFailedTries = Natural
|
||||
|
||||
constPause :: Microseconds -> PauseStrategy
|
||||
constPause basePause = PauseStrategy (const basePause)
|
||||
|
||||
linearPause :: Microseconds -> PauseStrategy
|
||||
linearPause basePause = PauseStrategy (* basePause)
|
||||
|
||||
expPause :: Microseconds -> PauseStrategy
|
||||
expPause basePause = PauseStrategy $ \i -> basePause * 2 ^ (i - 1)
|
||||
|
||||
customPause :: (NumFailedTries -> Microseconds) -> PauseStrategy
|
||||
customPause = PauseStrategy
|
64
waspc/test/Util/IO/RetryTest.hs
Normal file
64
waspc/test/Util/IO/RetryTest.hs
Normal file
@ -0,0 +1,64 @@
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
|
||||
module Util.IO.RetryTest where
|
||||
|
||||
import Control.Monad (forM_)
|
||||
import Control.Monad.State (MonadState (get), State, modify, runState)
|
||||
import Numeric.Natural (Natural)
|
||||
import Test.Tasty.Hspec (Spec, describe, it, shouldBe)
|
||||
import qualified Wasp.Util.IO.Retry as R
|
||||
|
||||
spec_RetryTest :: Spec
|
||||
spec_RetryTest = do
|
||||
describe "retry" $ do
|
||||
describe "when action succeeds on the first try" $ do
|
||||
it "runs action only once" $ do
|
||||
runMockRetry (R.constPause 42) 2 (mockAction (NumFails 0))
|
||||
`shouldBe` (Right (), [ActionCall])
|
||||
describe "when action fails 2 times and then succeeds" $ do
|
||||
let action = mockAction (NumFails 2)
|
||||
describe "and maxNumRetries >= 2" $ do
|
||||
it "will run it 3 times and end with success" $ do
|
||||
forM_ [2, 3, 4, 10] $ \maxNumRetries ->
|
||||
runMockRetry (R.constPause 42) maxNumRetries action
|
||||
`shouldBe` (Right (), [ActionCall, ThreadDelayCall 42, ActionCall, ThreadDelayCall 42, ActionCall])
|
||||
describe "and maxNumRetries < 2" $ do
|
||||
it "will run it (maxNumRetries + 1) times and end with failure" $ do
|
||||
runMockRetry (R.constPause 42) 0 action
|
||||
`shouldBe` (Left (), [ActionCall])
|
||||
runMockRetry (R.constPause 42) 1 action
|
||||
`shouldBe` (Left (), [ActionCall, ThreadDelayCall 42, ActionCall])
|
||||
describe "determines pauses according to provided pause strategy" $ do
|
||||
let action = mockAction (NumFails 3)
|
||||
let testPause = \pauseStrategy _expectedPauses@(p1, p2, p3) ->
|
||||
snd (runMockRetry pauseStrategy 5 action)
|
||||
`shouldBe` [ActionCall, ThreadDelayCall p1, ActionCall, ThreadDelayCall p2, ActionCall, ThreadDelayCall p3, ActionCall]
|
||||
it "for constPause" $ testPause (R.constPause 10) (10, 10, 10)
|
||||
it "for linearPause" $ testPause (R.linearPause 10) (10, 20, 30)
|
||||
it "for expPause" $ testPause (R.expPause 10) (10, 20, 40)
|
||||
it "for customPause" $ testPause (R.customPause (^ (2 :: Int))) (1, 4, 9)
|
||||
|
||||
runMockRetry :: R.PauseStrategy -> Natural -> MockAction -> (Either () (), [Event])
|
||||
runMockRetry pause maxNumRetries action = runState (R.retry pause maxNumRetries action) []
|
||||
|
||||
type MockAction = MockRetryMonad (Either () ())
|
||||
|
||||
mockAction :: NumFails -> MockAction
|
||||
mockAction (NumFails numFails) = do
|
||||
events <- get
|
||||
let result =
|
||||
if length (filter (== ActionCall) events) >= numFails
|
||||
then Right ()
|
||||
else Left ()
|
||||
modify (++ [ActionCall])
|
||||
return result
|
||||
|
||||
newtype NumFails = NumFails Int
|
||||
|
||||
data Event = ThreadDelayCall Int | ActionCall
|
||||
deriving (Show, Eq)
|
||||
|
||||
type MockRetryMonad = State [Event]
|
||||
|
||||
instance R.MonadRetry MockRetryMonad where
|
||||
rThreadDelay microseconds = modify (++ [ThreadDelayCall microseconds])
|
@ -344,6 +344,7 @@ library
|
||||
Wasp.Util.Control.Monad
|
||||
Wasp.Util.Fib
|
||||
Wasp.Util.IO
|
||||
Wasp.Util.IO.Retry
|
||||
Wasp.Util.Terminal
|
||||
Wasp.Util.FilePath
|
||||
Wasp.WaspignoreFile
|
||||
@ -555,6 +556,7 @@ test-suite waspc-test
|
||||
UtilTest
|
||||
Util.Diff
|
||||
Util.FilePathTest
|
||||
Util.IO.RetryTest
|
||||
SemanticVersionTest
|
||||
WaspignoreFileTest
|
||||
Paths_waspc
|
||||
|
Loading…
Reference in New Issue
Block a user