Support more GPT models in Mage + allow customizing both of them (#1598)

* Enable new OpenAI models for MAGE

* Add option to specify Plan GPT model from CLI

* Fix error in parsing planGptModel

* Remove check for GPT-4 availability

* Refactor initialization of GPT params

* Add GPT 3.5 Turbo 0613

* Fix GPT versions to exact values

Generic OpenAI models, such as 'GPT-4' and 'GPT-3_5_turbo',
point to the 'latest' of the respective OpenAI models.
This means that the exact models used can change unexpectedly.
This would not normally be an issue, but the newer models
tend to exhibit a performance drop.
Therefore, to future-proof our application, we fix the
model versions to prevent any unexpected model changes.

* Rename base->coding, plan->planning

* Remove unused _useGpt3IfGpt4NotAvailable field

* Remove confusing helper function

* Fix typo Planing -> Planning
This commit is contained in:
Janko Vidaković 2023-12-13 17:11:08 +01:00 committed by GitHub
parent 548b5a42bd
commit 34f6539590
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 69 additions and 61 deletions

View File

@ -61,8 +61,7 @@ createNewProjectOnDisk openAIApiKey waspProjectDir appName appDescription projec
CA.CodeAgentConfig
{ CA._openAIApiKey = openAIApiKey,
CA._writeFile = writeFileToDisk,
CA._writeLog = forwardLogToStdout,
CA._useGpt3IfGpt4NotAvailable = True
CA._writeLog = forwardLogToStdout
}
writeFileToDisk path content = do
@ -93,8 +92,7 @@ createNewProjectNonInteractiveToStdout projectName appDescription projectConfigJ
CA.CodeAgentConfig
{ CA._openAIApiKey = openAIApiKey,
CA._writeFile = writeFileToStdoutWithDelimiters,
CA._writeLog = writeLogToStdoutWithDelimiters,
CA._useGpt3IfGpt4NotAvailable = True
CA._writeLog = writeLogToStdoutWithDelimiters
}
liftIO $ generateNewProject codeAgentConfig appName appDescription projectConfig

View File

@ -12,7 +12,6 @@ module Wasp.AI.CodeAgent
queryChatGPT,
getTotalTokensUsage,
getOpenAIApiKey,
checkIfGpt4IsAvailable,
)
where
@ -42,8 +41,7 @@ newtype CodeAgent a = CodeAgent {_unCodeAgent :: ReaderT CodeAgentConfig (StateT
data CodeAgentConfig = CodeAgentConfig
{ _openAIApiKey :: !OpenAIApiKey,
_writeFile :: !(FilePath -> Text -> IO ()), -- TODO: Use StrongPath? Not clear which kind of path is it, rel, abs, ... .
_writeLog :: !(Text -> IO ()),
_useGpt3IfGpt4NotAvailable :: !Bool
_writeLog :: !(Text -> IO ())
}
instance MonadRetry CodeAgent where
@ -101,18 +99,8 @@ getAllFiles = gets $ H.toList . _files
queryChatGPT :: ChatGPTParams -> [ChatMessage] -> CodeAgent Text
queryChatGPT params messages = do
params' <- do
useGpt3IfGpt4NotAvailable <- asks _useGpt3IfGpt4NotAvailable
if ChatGPT._model params == ChatGPT.GPT_4 && useGpt3IfGpt4NotAvailable
then do
isAvailable <- checkIfGpt4IsAvailable
if not isAvailable
then return $ params {ChatGPT._model = ChatGPT.GPT_3_5_turbo_16k}
else return params
else return params
key <- asks _openAIApiKey
chatResponse <- queryChatGPTWithRetry key params' messages
chatResponse <- queryChatGPTWithRetry key params messages
modify $ \s -> s {_usage = _usage s <> [ChatGPT.usage chatResponse]}
return $ ChatGPT.getChatResponseContent chatResponse
where
@ -138,16 +126,6 @@ queryChatGPT params messages = do
getOpenAIApiKey :: CodeAgent OpenAIApiKey
getOpenAIApiKey = asks _openAIApiKey
checkIfGpt4IsAvailable :: CodeAgent Bool
checkIfGpt4IsAvailable = do
gets _isGpt4Available >>= \case
Just isAvailable -> pure isAvailable
Nothing -> do
key <- asks _openAIApiKey
isAvailable <- liftIO $ ChatGPT.checkIfGpt4IsAvailable key
modify $ \s -> s {_isGpt4Available = Just isAvailable}
return isAvailable
type NumTokens = Int
-- | Returns total tokens usage: (<num_prompt_tokens>, <num_completion_tokens>).

View File

@ -7,9 +7,10 @@ module Wasp.AI.GenerateNewProject.Common
getProjectPrimaryColor,
emptyNewProjectConfig,
queryChatGPTForJSON,
defaultChatGPTParams,
defaultChatGPTParamsForFixing,
writeToWaspFileEnd,
planningChatGPTParams,
codingChatGPTParams,
fixingChatGPTParams,
)
where
@ -33,7 +34,8 @@ data NewProjectConfig = NewProjectConfig
{ projectAuth :: !(Maybe AuthProvider),
-- One of the Tailwind color names: https://tailwindcss.com/docs/customizing-colors
projectPrimaryColor :: !(Maybe String),
projectDefaultGptModel :: !(Maybe GPT.Model),
projectCodingGptModel :: !(Maybe GPT.Model),
projectPlanningGptModel :: !(Maybe GPT.Model),
projectDefaultGptTemperature :: !(Maybe Float)
}
deriving (Show)
@ -42,13 +44,15 @@ instance Aeson.FromJSON NewProjectConfig where
parseJSON = withObject "NewProjectConfig" $ \obj -> do
auth <- obj .:? "auth"
primaryColor <- obj .:? "primaryColor"
defaultGptModel <- obj .:? "defaultGptModel"
codingGptModel <- obj .:? "codingGptModel"
planningGptModel <- obj .:? "planningGptModel"
defaultGptTemperature <- obj .:? "defaultGptTemperature"
return
( NewProjectConfig
{ projectAuth = auth,
projectPrimaryColor = primaryColor,
projectDefaultGptModel = defaultGptModel,
projectCodingGptModel = codingGptModel,
projectPlanningGptModel = planningGptModel,
projectDefaultGptTemperature = defaultGptTemperature
}
)
@ -58,7 +62,8 @@ emptyNewProjectConfig =
NewProjectConfig
{ projectAuth = Nothing,
projectPrimaryColor = Nothing,
projectDefaultGptModel = Nothing,
projectCodingGptModel = Nothing,
projectPlanningGptModel = Nothing,
projectDefaultGptTemperature = Nothing
}
@ -133,17 +138,26 @@ queryChatGPTForJSON chatGPTParams initChatMsgs = doQueryForJSON 0 0 initChatMsgs
maxNumFailuresPerRunBeforeGivingUpOnARun = 2
maxNumFailedRunsBeforeGivingUpCompletely = 2
defaultChatGPTParams :: NewProjectDetails -> ChatGPTParams
defaultChatGPTParams projectDetails =
codingChatGPTParams :: NewProjectDetails -> ChatGPTParams
codingChatGPTParams projectDetails =
GPT.ChatGPTParams
{ GPT._model = fromMaybe GPT.GPT_3_5_turbo_16k (projectDefaultGptModel $ _projectConfig projectDetails),
{ GPT._model = fromMaybe defaultCodingGptModel (projectCodingGptModel $ _projectConfig projectDetails),
GPT._temperature = Just $ fromMaybe 0.7 (projectDefaultGptTemperature $ _projectConfig projectDetails)
}
where
defaultCodingGptModel = GPT.GPT_3_5_turbo_0613
defaultChatGPTParamsForFixing :: NewProjectDetails -> ChatGPTParams
defaultChatGPTParamsForFixing projectDetails =
let params = defaultChatGPTParams projectDetails
in params {GPT._temperature = subtract 0.2 <$> GPT._temperature params}
-- | GPT params used for the planning step of project generation.
-- Model defaults to GPT-4 (0613 snapshot) unless overridden via the
-- project config; temperature defaults to 0.7 unless overridden.
planningChatGPTParams :: NewProjectDetails -> ChatGPTParams
planningChatGPTParams projectDetails =
  GPT.ChatGPTParams
    { GPT._model = chosenModel,
      GPT._temperature = Just chosenTemperature
    }
  where
    config = _projectConfig projectDetails
    chosenModel = fromMaybe GPT.GPT_4_0613 (projectPlanningGptModel config)
    chosenTemperature = fromMaybe 0.7 (projectDefaultGptTemperature config)
-- | Derive params for "fixing" queries from the given params: same model,
-- but a slightly lower temperature (by 0.2) for more deterministic output.
-- If the given params have no temperature set, none is set here either.
fixingChatGPTParams :: ChatGPTParams -> ChatGPTParams
fixingChatGPTParams params =
  params {GPT._temperature = fmap lowerTemp (GPT._temperature params)}
  where
    lowerTemp t = t - 0.2
writeToWaspFileEnd :: FilePath -> Text -> CodeAgent ()
writeToWaspFileEnd waspFilePath text = do

View File

@ -24,8 +24,8 @@ import NeatInterpolation (trimming)
import Wasp.AI.CodeAgent (CodeAgent, writeToFile, writeToLog)
import Wasp.AI.GenerateNewProject.Common
( NewProjectDetails (..),
defaultChatGPTParams,
defaultChatGPTParamsForFixing,
codingChatGPTParams,
fixingChatGPTParams,
queryChatGPTForJSON,
writeToWaspFileEnd,
)
@ -49,7 +49,7 @@ generateAndWriteOperation operationType newProjectDetails waspFilePath plan oper
generateOperation :: OperationType -> NewProjectDetails -> [Plan.Entity] -> Plan.Operation -> CodeAgent Operation
generateOperation operationType newProjectDetails entityPlans operationPlan = do
impl <-
queryChatGPTForJSON (defaultChatGPTParams newProjectDetails) chatMessages
queryChatGPTForJSON (codingChatGPTParams newProjectDetails) chatMessages
>>= fixOperationImplIfNeeded
return Operation {opImpl = impl, opPlan = operationPlan, opType = operationType}
where
@ -127,7 +127,7 @@ generateOperation operationType newProjectDetails entityPlans operationPlan = do
then return operationImpl
else do
let issuesText = T.pack $ intercalate "\n" ((" - " <>) <$> issues)
queryChatGPTForJSON (defaultChatGPTParamsForFixing newProjectDetails) $
queryChatGPTForJSON (fixingChatGPTParams $ codingChatGPTParams newProjectDetails) $
chatMessages
<> [ ChatMessage {role = Assistant, content = Util.Aeson.encodeToText operationImpl},
ChatMessage

View File

@ -15,7 +15,8 @@ import NeatInterpolation (trimming)
import Wasp.AI.CodeAgent (CodeAgent, getFile, writeToFile)
import Wasp.AI.GenerateNewProject.Common
( NewProjectDetails,
defaultChatGPTParamsForFixing,
codingChatGPTParams,
fixingChatGPTParams,
queryChatGPTForJSON,
)
import Wasp.AI.GenerateNewProject.Common.Prompts (appDescriptionBlock)
@ -32,7 +33,7 @@ fixOperationsJsFile newProjectDetails waspFilePath opJsFilePath = do
-- with npm dependencies installed, so we skipped it for now.
fixedOpJsFile <-
queryChatGPTForJSON
(defaultChatGPTParamsForFixing newProjectDetails)
(fixingChatGPTParams $ codingChatGPTParams newProjectDetails)
[ ChatMessage {role = System, content = Prompts.systemPrompt},
ChatMessage {role = User, content = fixOpJsFilePrompt currentWaspFileContent currentOpJsFileContent}
]

View File

@ -19,7 +19,7 @@ import NeatInterpolation (trimming)
import Wasp.AI.CodeAgent (CodeAgent, writeToFile, writeToLog)
import Wasp.AI.GenerateNewProject.Common
( NewProjectDetails (..),
defaultChatGPTParams,
codingChatGPTParams,
queryChatGPTForJSON,
writeToWaspFileEnd,
)
@ -40,7 +40,7 @@ generateAndWritePage newProjectDetails waspFilePath entityPlans queries actions
generatePage :: NewProjectDetails -> [Plan.Entity] -> [Operation] -> [Operation] -> Plan.Page -> CodeAgent Page
generatePage newProjectDetails entityPlans queries actions pPlan = do
impl <- queryChatGPTForJSON (defaultChatGPTParams newProjectDetails) chatMessages
impl <- queryChatGPTForJSON (codingChatGPTParams newProjectDetails) chatMessages
return Page {pageImpl = impl, pagePlan = pPlan}
where
chatMessages =

View File

@ -26,7 +26,8 @@ import Text.Printf (printf)
import Wasp.AI.CodeAgent (CodeAgent, getFile, writeToFile)
import Wasp.AI.GenerateNewProject.Common
( NewProjectDetails (..),
defaultChatGPTParamsForFixing,
codingChatGPTParams,
fixingChatGPTParams,
queryChatGPTForJSON,
)
import Wasp.AI.GenerateNewProject.Common.Prompts (appDescriptionBlock)
@ -125,7 +126,7 @@ fixPageComponent newProjectDetails waspFilePath pageComponentPath = do
currentPageComponentContent <- fromMaybe (error "couldn't find page file to fix") <$> getFile pageComponentPath
fixedPageComponent <-
queryChatGPTForJSON
(defaultChatGPTParamsForFixing newProjectDetails)
(fixingChatGPTParams $ codingChatGPTParams newProjectDetails)
[ ChatMessage {role = System, content = Prompts.systemPrompt},
ChatMessage {role = User, content = fixPageComponentPrompt currentWaspFileContent currentPageComponentContent}
]

View File

@ -25,13 +25,13 @@ import qualified Text.Parsec as Parsec
import Wasp.AI.CodeAgent (CodeAgent, writeToLog)
import Wasp.AI.GenerateNewProject.Common
( NewProjectDetails (..),
defaultChatGPTParams,
defaultChatGPTParamsForFixing,
fixingChatGPTParams,
planningChatGPTParams,
queryChatGPTForJSON,
)
import Wasp.AI.GenerateNewProject.Common.Prompts (appDescriptionBlock)
import qualified Wasp.AI.GenerateNewProject.Common.Prompts as Prompts
import Wasp.AI.OpenAI.ChatGPT (ChatGPTParams (_model), ChatMessage (..), ChatRole (..), Model (GPT_4))
import Wasp.AI.OpenAI.ChatGPT (ChatMessage (..), ChatRole (..))
import qualified Wasp.Psl.Format as Prisma
import qualified Wasp.Psl.Parser.Model as Psl.Parser
import qualified Wasp.Util.Aeson as Util.Aeson
@ -42,7 +42,7 @@ type PlanRule = String
generatePlan :: NewProjectDetails -> [PlanRule] -> CodeAgent Plan
generatePlan newProjectDetails planRules = do
writeToLog "Generating plan (slowest step, usually takes 30 to 90 seconds)..."
initialPlan <- queryChatGPTForJSON ((defaultChatGPTParams newProjectDetails) {_model = planGptModel}) chatMessages
initialPlan <- queryChatGPTForJSON (planningChatGPTParams newProjectDetails) chatMessages
writeToLog $ "Initial plan generated!\n" <> summarizePlan initialPlan
writeToLog "Fixing initial plan..."
fixedPlan <- fixPlanRepeatedly 3 initialPlan
@ -170,7 +170,7 @@ generatePlan newProjectDetails planRules = do
|]
writeToLog "Sending plan to GPT for fixing..."
fixedPlan <-
queryChatGPTForJSON ((defaultChatGPTParamsForFixing newProjectDetails) {_model = planGptModel}) $
queryChatGPTForJSON (fixingChatGPTParams $ planningChatGPTParams newProjectDetails) $
chatMessages
<> [ ChatMessage {role = Assistant, content = Util.Aeson.encodeToText plan'},
ChatMessage
@ -190,8 +190,6 @@ generatePlan newProjectDetails planRules = do
]
return (False, fixedPlan)
planGptModel = GPT_4
checkPlanForEntityIssues :: Plan -> [String]
checkPlanForEntityIssues plan =
checkNumEntities

View File

@ -18,7 +18,8 @@ import NeatInterpolation (trimming)
import Wasp.AI.CodeAgent (CodeAgent, getFile, writeToFile)
import Wasp.AI.GenerateNewProject.Common
( NewProjectDetails,
defaultChatGPTParamsForFixing,
codingChatGPTParams,
fixingChatGPTParams,
queryChatGPTForJSON,
)
import Wasp.AI.GenerateNewProject.Common.Prompts (appDescriptionBlock)
@ -51,7 +52,7 @@ fixWaspFile newProjectDetails waspFilePath plan = do
OnlyIfCompileErrors | null compileErrors -> return $ WaspFile {waspFileContent = wfContent}
_otherwise ->
queryChatGPTForJSON
(defaultChatGPTParamsForFixing newProjectDetails)
(fixingChatGPTParams $ codingChatGPTParams newProjectDetails)
[ ChatMessage {role = System, content = Prompts.systemPrompt},
ChatMessage {role = User, content = fixWaspFilePrompt wfContent compileErrors}
]

View File

@ -88,7 +88,17 @@ data ChatGPTParams = ChatGPTParams
deriving (Show)
-- TODO: There are some more data models there but for now we went with these core ones.
data Model = GPT_3_5_turbo | GPT_3_5_turbo_16k | GPT_4
data Model
= GPT_3_5_turbo_1106
| GPT_3_5_turbo
| GPT_3_5_turbo_16k
| GPT_3_5_turbo_0613
| GPT_3_5_turbo_16k_0613
| GPT_4_1106_Preview
| GPT_4
| GPT_4_32k
| GPT_4_0613
| GPT_4_32k_0613
deriving (Eq, Bounded, Enum)
instance Show Model where
@ -96,9 +106,16 @@ instance Show Model where
-- | The identifier under which OpenAI's API knows each of our models.
modelOpenAiId :: Model -> String
modelOpenAiId model =
  case model of
    GPT_3_5_turbo_1106 -> "gpt-3.5-turbo-1106"
    GPT_3_5_turbo -> "gpt-3.5-turbo"
    GPT_3_5_turbo_16k -> "gpt-3.5-turbo-16k"
    GPT_3_5_turbo_0613 -> "gpt-3.5-turbo-0613"
    GPT_3_5_turbo_16k_0613 -> "gpt-3.5-turbo-16k-0613"
    GPT_4_1106_Preview -> "gpt-4-1106-preview"
    GPT_4 -> "gpt-4"
    GPT_4_32k -> "gpt-4-32k"
    GPT_4_0613 -> "gpt-4-0613"
    GPT_4_32k_0613 -> "gpt-4-32k-0613"
instance FromJSON Model where
parseJSON = Aeson.withText "Model" $ \t ->