From 34f65395908e0d9cd28b62cda9903fa9f3babbde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janko=20Vidakovi=C4=87?= <58572059+jankovidakovic@users.noreply.github.com> Date: Wed, 13 Dec 2023 17:11:08 +0100 Subject: [PATCH] Support more GPT models in Mage + allow customizing both of them (#1598) * Enable new OpenAI models for MAGE * Add option to specify Plan GPT model from CLI * Fix error in parsing planGptModel * Remove check for GPT-4 availability * Refactor initialization of GPT params * Add GPT 3.5 Turbo 0613 * Fix GPT versions to exact values Generic OpenAI models, such as 'GPT-4' and 'GPT-3_5_turbo', point to the 'latest' of the respective OpenAI models. This means that the exact models used can change unexpectedly. This would not normally be an issue, but the newer models tend to exhibit a performance drop. Therefore, to future-proof our application, we fix the model versions to prevent any unexpected model changes. * Rename base->coding, plan->planning * Remove unused _useGpt3IfGpt4NotAvailable field * Remove confusing helper function * Fix typo Planing -> Planning --- .../Wasp/Cli/Command/CreateNewProject/AI.hs | 6 +-- waspc/src/Wasp/AI/CodeAgent.hs | 26 +----------- .../src/Wasp/AI/GenerateNewProject/Common.hs | 40 +++++++++++++------ .../Wasp/AI/GenerateNewProject/Operation.hs | 8 ++-- .../AI/GenerateNewProject/OperationsJsFile.hs | 5 ++- waspc/src/Wasp/AI/GenerateNewProject/Page.hs | 4 +- .../GenerateNewProject/PageComponentFile.hs | 5 ++- waspc/src/Wasp/AI/GenerateNewProject/Plan.hs | 12 +++--- .../Wasp/AI/GenerateNewProject/WaspFile.hs | 5 ++- waspc/src/Wasp/AI/OpenAI/ChatGPT.hs | 19 ++++++++- 10 files changed, 69 insertions(+), 61 deletions(-) diff --git a/waspc/cli/src/Wasp/Cli/Command/CreateNewProject/AI.hs b/waspc/cli/src/Wasp/Cli/Command/CreateNewProject/AI.hs index 4bd9b71f5..91a8b5930 100644 --- a/waspc/cli/src/Wasp/Cli/Command/CreateNewProject/AI.hs +++ b/waspc/cli/src/Wasp/Cli/Command/CreateNewProject/AI.hs @@ -61,8 +61,7 @@ createNewProjectOnDisk openAIApiKey waspProjectDir appName appDescription projec CA.CodeAgentConfig { CA._openAIApiKey = openAIApiKey, CA._writeFile = writeFileToDisk, - CA._writeLog = forwardLogToStdout, - CA._useGpt3IfGpt4NotAvailable = True + CA._writeLog = forwardLogToStdout } writeFileToDisk path content = do @@ -93,8 +92,7 @@ createNewProjectNonInteractiveToStdout projectName appDescription projectConfigJ CA.CodeAgentConfig { CA._openAIApiKey = openAIApiKey, CA._writeFile = writeFileToStdoutWithDelimiters, - CA._writeLog = writeLogToStdoutWithDelimiters, - CA._useGpt3IfGpt4NotAvailable = True + CA._writeLog = writeLogToStdoutWithDelimiters } liftIO $ generateNewProject codeAgentConfig appName appDescription projectConfig diff --git a/waspc/src/Wasp/AI/CodeAgent.hs b/waspc/src/Wasp/AI/CodeAgent.hs index 33454f8e2..35a6a35e4 100644 --- a/waspc/src/Wasp/AI/CodeAgent.hs +++ b/waspc/src/Wasp/AI/CodeAgent.hs @@ -12,7 +12,6 @@ module Wasp.AI.CodeAgent queryChatGPT, getTotalTokensUsage, getOpenAIApiKey, - checkIfGpt4IsAvailable, ) where @@ -42,8 +41,7 @@ newtype CodeAgent a = CodeAgent {_unCodeAgent :: ReaderT CodeAgentConfig (StateT data CodeAgentConfig = CodeAgentConfig { _openAIApiKey :: !OpenAIApiKey, _writeFile :: !(FilePath -> Text -> IO ()), -- TODO: Use StrongPath? Not clear which kind of path is it, rel, abs, ... . - _writeLog :: !(Text -> IO ()), - _useGpt3IfGpt4NotAvailable :: !Bool + _writeLog :: !(Text -> IO ()) } instance MonadRetry CodeAgent where @@ -101,18 +99,8 @@ getAllFiles = gets $ H.toList . _files queryChatGPT :: ChatGPTParams -> [ChatMessage] -> CodeAgent Text queryChatGPT params messages = do - params' <- do - useGpt3IfGpt4NotAvailable <- asks _useGpt3IfGpt4NotAvailable - if ChatGPT._model params == ChatGPT.GPT_4 && useGpt3IfGpt4NotAvailable - then do - isAvailable <- checkIfGpt4IsAvailable - if not isAvailable - then return $ params {ChatGPT._model = ChatGPT.GPT_3_5_turbo_16k} - else return params - else return params - key <- asks _openAIApiKey - chatResponse <- queryChatGPTWithRetry key params' messages + chatResponse <- queryChatGPTWithRetry key params messages modify $ \s -> s {_usage = _usage s <> [ChatGPT.usage chatResponse]} return $ ChatGPT.getChatResponseContent chatResponse where @@ -138,16 +126,6 @@ queryChatGPT params messages = do getOpenAIApiKey :: CodeAgent OpenAIApiKey getOpenAIApiKey = asks _openAIApiKey -checkIfGpt4IsAvailable :: CodeAgent Bool -checkIfGpt4IsAvailable = do - gets _isGpt4Available >>= \case - Just isAvailable -> pure isAvailable - Nothing -> do - key <- asks _openAIApiKey - isAvailable <- liftIO $ ChatGPT.checkIfGpt4IsAvailable key - modify $ \s -> s {_isGpt4Available = Just isAvailable} - return isAvailable - type NumTokens = Int -- | Returns total tokens usage: (, ). diff --git a/waspc/src/Wasp/AI/GenerateNewProject/Common.hs b/waspc/src/Wasp/AI/GenerateNewProject/Common.hs index 9e7d8fe5f..9cad527ab 100644 --- a/waspc/src/Wasp/AI/GenerateNewProject/Common.hs +++ b/waspc/src/Wasp/AI/GenerateNewProject/Common.hs @@ -7,9 +7,10 @@ module Wasp.AI.GenerateNewProject.Common getProjectPrimaryColor, emptyNewProjectConfig, queryChatGPTForJSON, - defaultChatGPTParams, - defaultChatGPTParamsForFixing, writeToWaspFileEnd, + planningChatGPTParams, + codingChatGPTParams, + fixingChatGPTParams, ) where @@ -33,7 +34,8 @@ data NewProjectConfig = NewProjectConfig { projectAuth :: !(Maybe AuthProvider), -- One of the Tailwind color names: https://tailwindcss.com/docs/customizing-colors projectPrimaryColor :: !(Maybe String), - projectDefaultGptModel :: !(Maybe GPT.Model), + projectCodingGptModel :: !(Maybe GPT.Model), + projectPlanningGptModel :: !(Maybe GPT.Model), projectDefaultGptTemperature :: !(Maybe Float) } deriving (Show) @@ -42,13 +44,15 @@ instance Aeson.FromJSON NewProjectConfig where parseJSON = withObject "NewProjectConfig" $ \obj -> do auth <- obj .:? "auth" primaryColor <- obj .:? "primaryColor" - defaultGptModel <- obj .:? "defaultGptModel" + codingGptModel <- obj .:? "codingGptModel" + planningGptModel <- obj .:? "planningGptModel" defaultGptTemperature <- obj .:? "defaultGptTemperature" return ( NewProjectConfig { projectAuth = auth, projectPrimaryColor = primaryColor, - projectDefaultGptModel = defaultGptModel, + projectCodingGptModel = codingGptModel, + projectPlanningGptModel = planningGptModel, projectDefaultGptTemperature = defaultGptTemperature } ) @@ -58,7 +62,8 @@ emptyNewProjectConfig = NewProjectConfig { projectAuth = Nothing, projectPrimaryColor = Nothing, - projectDefaultGptModel = Nothing, + projectCodingGptModel = Nothing, + projectPlanningGptModel = Nothing, projectDefaultGptTemperature = Nothing } @@ -133,17 +138,26 @@ queryChatGPTForJSON chatGPTParams initChatMsgs = doQueryForJSON 0 0 initChatMsgs maxNumFailuresPerRunBeforeGivingUpOnARun = 2 maxNumFailedRunsBeforeGivingUpCompletely = 2 -defaultChatGPTParams :: NewProjectDetails -> ChatGPTParams -defaultChatGPTParams projectDetails = +codingChatGPTParams :: NewProjectDetails -> ChatGPTParams +codingChatGPTParams projectDetails = GPT.ChatGPTParams - { GPT._model = fromMaybe GPT.GPT_3_5_turbo_16k (projectDefaultGptModel $ _projectConfig projectDetails), + { GPT._model = fromMaybe defaultCodingGptModel (projectCodingGptModel $ _projectConfig projectDetails), GPT._temperature = Just $ fromMaybe 0.7 (projectDefaultGptTemperature $ _projectConfig projectDetails) } + where + defaultCodingGptModel = GPT.GPT_3_5_turbo_0613 -defaultChatGPTParamsForFixing :: NewProjectDetails -> ChatGPTParams -defaultChatGPTParamsForFixing projectDetails = - let params = defaultChatGPTParams projectDetails - in params {GPT._temperature = subtract 0.2 <$> GPT._temperature params} +planningChatGPTParams :: NewProjectDetails -> ChatGPTParams +planningChatGPTParams projectDetails = + GPT.ChatGPTParams + { GPT._model = fromMaybe defaultPlanningGptModel (projectPlanningGptModel $ _projectConfig projectDetails), + GPT._temperature = Just $ fromMaybe 0.7 (projectDefaultGptTemperature $ _projectConfig projectDetails) + } + where + defaultPlanningGptModel = GPT.GPT_4_0613 + +fixingChatGPTParams :: ChatGPTParams -> ChatGPTParams +fixingChatGPTParams params = params {GPT._temperature = subtract 0.2 <$> GPT._temperature params} writeToWaspFileEnd :: FilePath -> Text -> CodeAgent () writeToWaspFileEnd waspFilePath text = do diff --git a/waspc/src/Wasp/AI/GenerateNewProject/Operation.hs b/waspc/src/Wasp/AI/GenerateNewProject/Operation.hs index b20bbc0e8..de7de7ee0 100644 --- a/waspc/src/Wasp/AI/GenerateNewProject/Operation.hs +++ b/waspc/src/Wasp/AI/GenerateNewProject/Operation.hs @@ -24,8 +24,8 @@ import NeatInterpolation (trimming) import Wasp.AI.CodeAgent (CodeAgent, writeToFile, writeToLog) import Wasp.AI.GenerateNewProject.Common ( NewProjectDetails (..), - defaultChatGPTParams, - defaultChatGPTParamsForFixing, + codingChatGPTParams, + fixingChatGPTParams, queryChatGPTForJSON, writeToWaspFileEnd, ) @@ -49,7 +49,7 @@ generateAndWriteOperation operationType newProjectDetails waspFilePath plan oper generateOperation :: OperationType -> NewProjectDetails -> [Plan.Entity] -> Plan.Operation -> CodeAgent Operation generateOperation operationType newProjectDetails entityPlans operationPlan = do impl <- - queryChatGPTForJSON (defaultChatGPTParams newProjectDetails) chatMessages + queryChatGPTForJSON (codingChatGPTParams newProjectDetails) chatMessages >>= fixOperationImplIfNeeded return Operation {opImpl = impl, opPlan = operationPlan, opType = operationType} where @@ -127,7 +127,7 @@ generateOperation operationType newProjectDetails entityPlans operationPlan = do then return operationImpl else do let issuesText = T.pack $ intercalate "\n" ((" - " <>) <$> issues) - queryChatGPTForJSON (defaultChatGPTParamsForFixing newProjectDetails) $ + queryChatGPTForJSON (fixingChatGPTParams $ codingChatGPTParams newProjectDetails) $ chatMessages <> [ ChatMessage {role = Assistant, content = Util.Aeson.encodeToText operationImpl}, ChatMessage diff --git a/waspc/src/Wasp/AI/GenerateNewProject/OperationsJsFile.hs b/waspc/src/Wasp/AI/GenerateNewProject/OperationsJsFile.hs index 37f675a47..7f5e48884 100644 --- a/waspc/src/Wasp/AI/GenerateNewProject/OperationsJsFile.hs +++ b/waspc/src/Wasp/AI/GenerateNewProject/OperationsJsFile.hs @@ -15,7 +15,8 @@ import NeatInterpolation (trimming) import Wasp.AI.CodeAgent (CodeAgent, getFile, writeToFile) import Wasp.AI.GenerateNewProject.Common ( NewProjectDetails, - defaultChatGPTParamsForFixing, + codingChatGPTParams, + fixingChatGPTParams, queryChatGPTForJSON, ) import Wasp.AI.GenerateNewProject.Common.Prompts (appDescriptionBlock) @@ -32,7 +33,7 @@ fixOperationsJsFile newProjectDetails waspFilePath opJsFilePath = do -- with npm dependencies installed, so we skipped it for now. fixedOpJsFile <- queryChatGPTForJSON - (defaultChatGPTParamsForFixing newProjectDetails) + (fixingChatGPTParams $ codingChatGPTParams newProjectDetails) [ ChatMessage {role = System, content = Prompts.systemPrompt}, ChatMessage {role = User, content = fixOpJsFilePrompt currentWaspFileContent currentOpJsFileContent} ] diff --git a/waspc/src/Wasp/AI/GenerateNewProject/Page.hs b/waspc/src/Wasp/AI/GenerateNewProject/Page.hs index b5df94924..77baf8a99 100644 --- a/waspc/src/Wasp/AI/GenerateNewProject/Page.hs +++ b/waspc/src/Wasp/AI/GenerateNewProject/Page.hs @@ -19,7 +19,7 @@ import NeatInterpolation (trimming) import Wasp.AI.CodeAgent (CodeAgent, writeToFile, writeToLog) import Wasp.AI.GenerateNewProject.Common ( NewProjectDetails (..), - defaultChatGPTParams, + codingChatGPTParams, queryChatGPTForJSON, writeToWaspFileEnd, ) @@ -40,7 +40,7 @@ generateAndWritePage newProjectDetails waspFilePath entityPlans queries actions generatePage :: NewProjectDetails -> [Plan.Entity] -> [Operation] -> [Operation] -> Plan.Page -> CodeAgent Page generatePage newProjectDetails entityPlans queries actions pPlan = do - impl <- queryChatGPTForJSON (defaultChatGPTParams newProjectDetails) chatMessages + impl <- queryChatGPTForJSON (codingChatGPTParams newProjectDetails) chatMessages return Page {pageImpl = impl, pagePlan = pPlan} where chatMessages = diff --git a/waspc/src/Wasp/AI/GenerateNewProject/PageComponentFile.hs b/waspc/src/Wasp/AI/GenerateNewProject/PageComponentFile.hs index 21cfc3999..8e74a0a7a 100644 --- a/waspc/src/Wasp/AI/GenerateNewProject/PageComponentFile.hs +++ b/waspc/src/Wasp/AI/GenerateNewProject/PageComponentFile.hs @@ -26,7 +26,8 @@ import Text.Printf (printf) import Wasp.AI.CodeAgent (CodeAgent, getFile, writeToFile) import Wasp.AI.GenerateNewProject.Common ( NewProjectDetails (..), - defaultChatGPTParamsForFixing, + codingChatGPTParams, + fixingChatGPTParams, queryChatGPTForJSON, ) import Wasp.AI.GenerateNewProject.Common.Prompts (appDescriptionBlock) @@ -125,7 +126,7 @@ fixPageComponent newProjectDetails waspFilePath pageComponentPath = do currentPageComponentContent <- fromMaybe (error "couldn't find page file to fix") <$> getFile pageComponentPath fixedPageComponent <- queryChatGPTForJSON - (defaultChatGPTParamsForFixing newProjectDetails) + (fixingChatGPTParams $ codingChatGPTParams newProjectDetails) [ ChatMessage {role = System, content = Prompts.systemPrompt}, ChatMessage {role = User, content = fixPageComponentPrompt currentWaspFileContent currentPageComponentContent} ] diff --git a/waspc/src/Wasp/AI/GenerateNewProject/Plan.hs b/waspc/src/Wasp/AI/GenerateNewProject/Plan.hs index 5ee9c5e8b..4c5f22aed 100644 --- a/waspc/src/Wasp/AI/GenerateNewProject/Plan.hs +++ b/waspc/src/Wasp/AI/GenerateNewProject/Plan.hs @@ -25,13 +25,13 @@ import qualified Text.Parsec as Parsec import Wasp.AI.CodeAgent (CodeAgent, writeToLog) import Wasp.AI.GenerateNewProject.Common ( NewProjectDetails (..), - defaultChatGPTParams, - defaultChatGPTParamsForFixing, + fixingChatGPTParams, + planningChatGPTParams, queryChatGPTForJSON, ) import Wasp.AI.GenerateNewProject.Common.Prompts (appDescriptionBlock) import qualified Wasp.AI.GenerateNewProject.Common.Prompts as Prompts -import Wasp.AI.OpenAI.ChatGPT (ChatGPTParams (_model), ChatMessage (..), ChatRole (..), Model (GPT_4)) +import Wasp.AI.OpenAI.ChatGPT (ChatMessage (..), ChatRole (..)) import qualified Wasp.Psl.Format as Prisma import qualified Wasp.Psl.Parser.Model as Psl.Parser import qualified Wasp.Util.Aeson as Util.Aeson @@ -42,7 +42,7 @@ type PlanRule = String generatePlan :: NewProjectDetails -> [PlanRule] -> CodeAgent Plan generatePlan newProjectDetails planRules = do writeToLog "Generating plan (slowest step, usually takes 30 to 90 seconds)..." - initialPlan <- queryChatGPTForJSON ((defaultChatGPTParams newProjectDetails) {_model = planGptModel}) chatMessages + initialPlan <- queryChatGPTForJSON (planningChatGPTParams newProjectDetails) chatMessages writeToLog $ "Initial plan generated!\n" <> summarizePlan initialPlan writeToLog "Fixing initial plan..." fixedPlan <- fixPlanRepeatedly 3 initialPlan @@ -170,7 +170,7 @@ generatePlan newProjectDetails planRules = do |] writeToLog "Sending plan to GPT for fixing..." fixedPlan <- - queryChatGPTForJSON ((defaultChatGPTParamsForFixing newProjectDetails) {_model = planGptModel}) $ + queryChatGPTForJSON (fixingChatGPTParams $ planningChatGPTParams newProjectDetails) $ chatMessages <> [ ChatMessage {role = Assistant, content = Util.Aeson.encodeToText plan'}, ChatMessage @@ -190,8 +190,6 @@ generatePlan newProjectDetails planRules = do ] return (False, fixedPlan) - planGptModel = GPT_4 - checkPlanForEntityIssues :: Plan -> [String] checkPlanForEntityIssues plan = checkNumEntities diff --git a/waspc/src/Wasp/AI/GenerateNewProject/WaspFile.hs b/waspc/src/Wasp/AI/GenerateNewProject/WaspFile.hs index 8c38efaa6..2089ad07d 100644 --- a/waspc/src/Wasp/AI/GenerateNewProject/WaspFile.hs +++ b/waspc/src/Wasp/AI/GenerateNewProject/WaspFile.hs @@ -18,7 +18,8 @@ import NeatInterpolation (trimming) import Wasp.AI.CodeAgent (CodeAgent, getFile, writeToFile) import Wasp.AI.GenerateNewProject.Common ( NewProjectDetails, - defaultChatGPTParamsForFixing, + codingChatGPTParams, + fixingChatGPTParams, queryChatGPTForJSON, ) import Wasp.AI.GenerateNewProject.Common.Prompts (appDescriptionBlock) @@ -51,7 +52,7 @@ fixWaspFile newProjectDetails waspFilePath plan = do OnlyIfCompileErrors | null compileErrors -> return $ WaspFile {waspFileContent = wfContent} _otherwise -> queryChatGPTForJSON - (defaultChatGPTParamsForFixing newProjectDetails) + (fixingChatGPTParams $ codingChatGPTParams newProjectDetails) [ ChatMessage {role = System, content = Prompts.systemPrompt}, ChatMessage {role = User, content = fixWaspFilePrompt wfContent compileErrors} ] diff --git a/waspc/src/Wasp/AI/OpenAI/ChatGPT.hs b/waspc/src/Wasp/AI/OpenAI/ChatGPT.hs index fa50898bb..5a0981b5f 100644 --- a/waspc/src/Wasp/AI/OpenAI/ChatGPT.hs +++ b/waspc/src/Wasp/AI/OpenAI/ChatGPT.hs @@ -88,7 +88,17 @@ data ChatGPTParams = ChatGPTParams deriving (Show) -- TODO: There are some more data models there but for now we went with these core ones. -data Model = GPT_3_5_turbo | GPT_3_5_turbo_16k | GPT_4 +data Model + = GPT_3_5_turbo_1106 + | GPT_3_5_turbo + | GPT_3_5_turbo_16k + | GPT_3_5_turbo_0613 + | GPT_3_5_turbo_16k_0613 + | GPT_4_1106_Preview + | GPT_4 + | GPT_4_32k + | GPT_4_0613 + | GPT_4_32k_0613 deriving (Eq, Bounded, Enum) instance Show Model where @@ -96,9 +106,16 @@ instance Show Model where modelOpenAiId :: Model -> String modelOpenAiId = \case + GPT_3_5_turbo_1106 -> "gpt-3.5-turbo-1106" GPT_3_5_turbo -> "gpt-3.5-turbo" GPT_3_5_turbo_16k -> "gpt-3.5-turbo-16k" + GPT_3_5_turbo_0613 -> "gpt-3.5-turbo-0613" + GPT_3_5_turbo_16k_0613 -> "gpt-3.5-turbo-16k-0613" + GPT_4_1106_Preview -> "gpt-4-1106-preview" GPT_4 -> "gpt-4" + GPT_4_32k -> "gpt-4-32k" + GPT_4_0613 -> "gpt-4-0613" + GPT_4_32k_0613 -> "gpt-4-32k-0613" instance FromJSON Model where parseJSON = Aeson.withText "Model" $ \t ->