mirror of
https://github.com/wasp-lang/wasp.git
synced 2024-11-22 09:33:45 +03:00
Support more GPT models in Mage + allow customizing both of them (#1598)
* Enable new OpenAI models for MAGE
* Add option to specify Plan GPT model from CLI
* Fix error in parsing planGptModel
* Remove check for GPT-4 availability
* Refactor initialization of GPT params
* Add GPT 3.5 Turbo 0613
* Fix GPT versions to exact values
  Generic OpenAI models, such as 'GPT-4' and 'GPT-3_5_turbo', point to the 'latest' version of the respective OpenAI models. This means that the exact models used can change unexpectedly. This would not normally be an issue, but the newer models tend to exhibit a performance drop. Therefore, to future-proof our application, we fix the model versions to prevent any unexpected model changes.
* Rename base->coding, plan->planning
* Remove unused _useGpt3IfGpt4NotAvailable field
* Remove confusing helper function
* Fix typo Planing -> Planning
This commit is contained in:
parent
548b5a42bd
commit
34f6539590
@ -61,8 +61,7 @@ createNewProjectOnDisk openAIApiKey waspProjectDir appName appDescription projec
|
||||
CA.CodeAgentConfig
|
||||
{ CA._openAIApiKey = openAIApiKey,
|
||||
CA._writeFile = writeFileToDisk,
|
||||
CA._writeLog = forwardLogToStdout,
|
||||
CA._useGpt3IfGpt4NotAvailable = True
|
||||
CA._writeLog = forwardLogToStdout
|
||||
}
|
||||
|
||||
writeFileToDisk path content = do
|
||||
@ -93,8 +92,7 @@ createNewProjectNonInteractiveToStdout projectName appDescription projectConfigJ
|
||||
CA.CodeAgentConfig
|
||||
{ CA._openAIApiKey = openAIApiKey,
|
||||
CA._writeFile = writeFileToStdoutWithDelimiters,
|
||||
CA._writeLog = writeLogToStdoutWithDelimiters,
|
||||
CA._useGpt3IfGpt4NotAvailable = True
|
||||
CA._writeLog = writeLogToStdoutWithDelimiters
|
||||
}
|
||||
|
||||
liftIO $ generateNewProject codeAgentConfig appName appDescription projectConfig
|
||||
|
@ -12,7 +12,6 @@ module Wasp.AI.CodeAgent
|
||||
queryChatGPT,
|
||||
getTotalTokensUsage,
|
||||
getOpenAIApiKey,
|
||||
checkIfGpt4IsAvailable,
|
||||
)
|
||||
where
|
||||
|
||||
@ -42,8 +41,7 @@ newtype CodeAgent a = CodeAgent {_unCodeAgent :: ReaderT CodeAgentConfig (StateT
|
||||
data CodeAgentConfig = CodeAgentConfig
|
||||
{ _openAIApiKey :: !OpenAIApiKey,
|
||||
_writeFile :: !(FilePath -> Text -> IO ()), -- TODO: Use StrongPath? Not clear which kind of path is it, rel, abs, ... .
|
||||
_writeLog :: !(Text -> IO ()),
|
||||
_useGpt3IfGpt4NotAvailable :: !Bool
|
||||
_writeLog :: !(Text -> IO ())
|
||||
}
|
||||
|
||||
instance MonadRetry CodeAgent where
|
||||
@ -101,18 +99,8 @@ getAllFiles = gets $ H.toList . _files
|
||||
|
||||
queryChatGPT :: ChatGPTParams -> [ChatMessage] -> CodeAgent Text
|
||||
queryChatGPT params messages = do
|
||||
params' <- do
|
||||
useGpt3IfGpt4NotAvailable <- asks _useGpt3IfGpt4NotAvailable
|
||||
if ChatGPT._model params == ChatGPT.GPT_4 && useGpt3IfGpt4NotAvailable
|
||||
then do
|
||||
isAvailable <- checkIfGpt4IsAvailable
|
||||
if not isAvailable
|
||||
then return $ params {ChatGPT._model = ChatGPT.GPT_3_5_turbo_16k}
|
||||
else return params
|
||||
else return params
|
||||
|
||||
key <- asks _openAIApiKey
|
||||
chatResponse <- queryChatGPTWithRetry key params' messages
|
||||
chatResponse <- queryChatGPTWithRetry key params messages
|
||||
modify $ \s -> s {_usage = _usage s <> [ChatGPT.usage chatResponse]}
|
||||
return $ ChatGPT.getChatResponseContent chatResponse
|
||||
where
|
||||
@ -138,16 +126,6 @@ queryChatGPT params messages = do
|
||||
getOpenAIApiKey :: CodeAgent OpenAIApiKey
|
||||
getOpenAIApiKey = asks _openAIApiKey
|
||||
|
||||
checkIfGpt4IsAvailable :: CodeAgent Bool
|
||||
checkIfGpt4IsAvailable = do
|
||||
gets _isGpt4Available >>= \case
|
||||
Just isAvailable -> pure isAvailable
|
||||
Nothing -> do
|
||||
key <- asks _openAIApiKey
|
||||
isAvailable <- liftIO $ ChatGPT.checkIfGpt4IsAvailable key
|
||||
modify $ \s -> s {_isGpt4Available = Just isAvailable}
|
||||
return isAvailable
|
||||
|
||||
type NumTokens = Int
|
||||
|
||||
-- | Returns total tokens usage: (<num_prompt_tokens>, <num_completion_tokens>).
|
||||
|
@ -7,9 +7,10 @@ module Wasp.AI.GenerateNewProject.Common
|
||||
getProjectPrimaryColor,
|
||||
emptyNewProjectConfig,
|
||||
queryChatGPTForJSON,
|
||||
defaultChatGPTParams,
|
||||
defaultChatGPTParamsForFixing,
|
||||
writeToWaspFileEnd,
|
||||
planningChatGPTParams,
|
||||
codingChatGPTParams,
|
||||
fixingChatGPTParams,
|
||||
)
|
||||
where
|
||||
|
||||
@ -33,7 +34,8 @@ data NewProjectConfig = NewProjectConfig
|
||||
{ projectAuth :: !(Maybe AuthProvider),
|
||||
-- One of the Tailwind color names: https://tailwindcss.com/docs/customizing-colors
|
||||
projectPrimaryColor :: !(Maybe String),
|
||||
projectDefaultGptModel :: !(Maybe GPT.Model),
|
||||
projectCodingGptModel :: !(Maybe GPT.Model),
|
||||
projectPlanningGptModel :: !(Maybe GPT.Model),
|
||||
projectDefaultGptTemperature :: !(Maybe Float)
|
||||
}
|
||||
deriving (Show)
|
||||
@ -42,13 +44,15 @@ instance Aeson.FromJSON NewProjectConfig where
|
||||
parseJSON = withObject "NewProjectConfig" $ \obj -> do
|
||||
auth <- obj .:? "auth"
|
||||
primaryColor <- obj .:? "primaryColor"
|
||||
defaultGptModel <- obj .:? "defaultGptModel"
|
||||
codingGptModel <- obj .:? "codingGptModel"
|
||||
planningGptModel <- obj .:? "planningGptModel"
|
||||
defaultGptTemperature <- obj .:? "defaultGptTemperature"
|
||||
return
|
||||
( NewProjectConfig
|
||||
{ projectAuth = auth,
|
||||
projectPrimaryColor = primaryColor,
|
||||
projectDefaultGptModel = defaultGptModel,
|
||||
projectCodingGptModel = codingGptModel,
|
||||
projectPlanningGptModel = planningGptModel,
|
||||
projectDefaultGptTemperature = defaultGptTemperature
|
||||
}
|
||||
)
|
||||
@ -58,7 +62,8 @@ emptyNewProjectConfig =
|
||||
NewProjectConfig
|
||||
{ projectAuth = Nothing,
|
||||
projectPrimaryColor = Nothing,
|
||||
projectDefaultGptModel = Nothing,
|
||||
projectCodingGptModel = Nothing,
|
||||
projectPlanningGptModel = Nothing,
|
||||
projectDefaultGptTemperature = Nothing
|
||||
}
|
||||
|
||||
@ -133,17 +138,26 @@ queryChatGPTForJSON chatGPTParams initChatMsgs = doQueryForJSON 0 0 initChatMsgs
|
||||
maxNumFailuresPerRunBeforeGivingUpOnARun = 2
|
||||
maxNumFailedRunsBeforeGivingUpCompletely = 2
|
||||
|
||||
defaultChatGPTParams :: NewProjectDetails -> ChatGPTParams
|
||||
defaultChatGPTParams projectDetails =
|
||||
codingChatGPTParams :: NewProjectDetails -> ChatGPTParams
|
||||
codingChatGPTParams projectDetails =
|
||||
GPT.ChatGPTParams
|
||||
{ GPT._model = fromMaybe GPT.GPT_3_5_turbo_16k (projectDefaultGptModel $ _projectConfig projectDetails),
|
||||
{ GPT._model = fromMaybe defaultCodingGptModel (projectCodingGptModel $ _projectConfig projectDetails),
|
||||
GPT._temperature = Just $ fromMaybe 0.7 (projectDefaultGptTemperature $ _projectConfig projectDetails)
|
||||
}
|
||||
where
|
||||
defaultCodingGptModel = GPT.GPT_3_5_turbo_0613
|
||||
|
||||
defaultChatGPTParamsForFixing :: NewProjectDetails -> ChatGPTParams
|
||||
defaultChatGPTParamsForFixing projectDetails =
|
||||
let params = defaultChatGPTParams projectDetails
|
||||
in params {GPT._temperature = subtract 0.2 <$> GPT._temperature params}
|
||||
planningChatGPTParams :: NewProjectDetails -> ChatGPTParams
|
||||
planningChatGPTParams projectDetails =
|
||||
GPT.ChatGPTParams
|
||||
{ GPT._model = fromMaybe defaultPlanningGptModel (projectPlanningGptModel $ _projectConfig projectDetails),
|
||||
GPT._temperature = Just $ fromMaybe 0.7 (projectDefaultGptTemperature $ _projectConfig projectDetails)
|
||||
}
|
||||
where
|
||||
defaultPlanningGptModel = GPT.GPT_4_0613
|
||||
|
||||
fixingChatGPTParams :: ChatGPTParams -> ChatGPTParams
|
||||
fixingChatGPTParams params = params {GPT._temperature = subtract 0.2 <$> GPT._temperature params}
|
||||
|
||||
writeToWaspFileEnd :: FilePath -> Text -> CodeAgent ()
|
||||
writeToWaspFileEnd waspFilePath text = do
|
||||
|
@ -24,8 +24,8 @@ import NeatInterpolation (trimming)
|
||||
import Wasp.AI.CodeAgent (CodeAgent, writeToFile, writeToLog)
|
||||
import Wasp.AI.GenerateNewProject.Common
|
||||
( NewProjectDetails (..),
|
||||
defaultChatGPTParams,
|
||||
defaultChatGPTParamsForFixing,
|
||||
codingChatGPTParams,
|
||||
fixingChatGPTParams,
|
||||
queryChatGPTForJSON,
|
||||
writeToWaspFileEnd,
|
||||
)
|
||||
@ -49,7 +49,7 @@ generateAndWriteOperation operationType newProjectDetails waspFilePath plan oper
|
||||
generateOperation :: OperationType -> NewProjectDetails -> [Plan.Entity] -> Plan.Operation -> CodeAgent Operation
|
||||
generateOperation operationType newProjectDetails entityPlans operationPlan = do
|
||||
impl <-
|
||||
queryChatGPTForJSON (defaultChatGPTParams newProjectDetails) chatMessages
|
||||
queryChatGPTForJSON (codingChatGPTParams newProjectDetails) chatMessages
|
||||
>>= fixOperationImplIfNeeded
|
||||
return Operation {opImpl = impl, opPlan = operationPlan, opType = operationType}
|
||||
where
|
||||
@ -127,7 +127,7 @@ generateOperation operationType newProjectDetails entityPlans operationPlan = do
|
||||
then return operationImpl
|
||||
else do
|
||||
let issuesText = T.pack $ intercalate "\n" ((" - " <>) <$> issues)
|
||||
queryChatGPTForJSON (defaultChatGPTParamsForFixing newProjectDetails) $
|
||||
queryChatGPTForJSON (fixingChatGPTParams $ codingChatGPTParams newProjectDetails) $
|
||||
chatMessages
|
||||
<> [ ChatMessage {role = Assistant, content = Util.Aeson.encodeToText operationImpl},
|
||||
ChatMessage
|
||||
|
@ -15,7 +15,8 @@ import NeatInterpolation (trimming)
|
||||
import Wasp.AI.CodeAgent (CodeAgent, getFile, writeToFile)
|
||||
import Wasp.AI.GenerateNewProject.Common
|
||||
( NewProjectDetails,
|
||||
defaultChatGPTParamsForFixing,
|
||||
codingChatGPTParams,
|
||||
fixingChatGPTParams,
|
||||
queryChatGPTForJSON,
|
||||
)
|
||||
import Wasp.AI.GenerateNewProject.Common.Prompts (appDescriptionBlock)
|
||||
@ -32,7 +33,7 @@ fixOperationsJsFile newProjectDetails waspFilePath opJsFilePath = do
|
||||
-- with npm dependencies installed, so we skipped it for now.
|
||||
fixedOpJsFile <-
|
||||
queryChatGPTForJSON
|
||||
(defaultChatGPTParamsForFixing newProjectDetails)
|
||||
(fixingChatGPTParams $ codingChatGPTParams newProjectDetails)
|
||||
[ ChatMessage {role = System, content = Prompts.systemPrompt},
|
||||
ChatMessage {role = User, content = fixOpJsFilePrompt currentWaspFileContent currentOpJsFileContent}
|
||||
]
|
||||
|
@ -19,7 +19,7 @@ import NeatInterpolation (trimming)
|
||||
import Wasp.AI.CodeAgent (CodeAgent, writeToFile, writeToLog)
|
||||
import Wasp.AI.GenerateNewProject.Common
|
||||
( NewProjectDetails (..),
|
||||
defaultChatGPTParams,
|
||||
codingChatGPTParams,
|
||||
queryChatGPTForJSON,
|
||||
writeToWaspFileEnd,
|
||||
)
|
||||
@ -40,7 +40,7 @@ generateAndWritePage newProjectDetails waspFilePath entityPlans queries actions
|
||||
|
||||
generatePage :: NewProjectDetails -> [Plan.Entity] -> [Operation] -> [Operation] -> Plan.Page -> CodeAgent Page
|
||||
generatePage newProjectDetails entityPlans queries actions pPlan = do
|
||||
impl <- queryChatGPTForJSON (defaultChatGPTParams newProjectDetails) chatMessages
|
||||
impl <- queryChatGPTForJSON (codingChatGPTParams newProjectDetails) chatMessages
|
||||
return Page {pageImpl = impl, pagePlan = pPlan}
|
||||
where
|
||||
chatMessages =
|
||||
|
@ -26,7 +26,8 @@ import Text.Printf (printf)
|
||||
import Wasp.AI.CodeAgent (CodeAgent, getFile, writeToFile)
|
||||
import Wasp.AI.GenerateNewProject.Common
|
||||
( NewProjectDetails (..),
|
||||
defaultChatGPTParamsForFixing,
|
||||
codingChatGPTParams,
|
||||
fixingChatGPTParams,
|
||||
queryChatGPTForJSON,
|
||||
)
|
||||
import Wasp.AI.GenerateNewProject.Common.Prompts (appDescriptionBlock)
|
||||
@ -125,7 +126,7 @@ fixPageComponent newProjectDetails waspFilePath pageComponentPath = do
|
||||
currentPageComponentContent <- fromMaybe (error "couldn't find page file to fix") <$> getFile pageComponentPath
|
||||
fixedPageComponent <-
|
||||
queryChatGPTForJSON
|
||||
(defaultChatGPTParamsForFixing newProjectDetails)
|
||||
(fixingChatGPTParams $ codingChatGPTParams newProjectDetails)
|
||||
[ ChatMessage {role = System, content = Prompts.systemPrompt},
|
||||
ChatMessage {role = User, content = fixPageComponentPrompt currentWaspFileContent currentPageComponentContent}
|
||||
]
|
||||
|
@ -25,13 +25,13 @@ import qualified Text.Parsec as Parsec
|
||||
import Wasp.AI.CodeAgent (CodeAgent, writeToLog)
|
||||
import Wasp.AI.GenerateNewProject.Common
|
||||
( NewProjectDetails (..),
|
||||
defaultChatGPTParams,
|
||||
defaultChatGPTParamsForFixing,
|
||||
fixingChatGPTParams,
|
||||
planningChatGPTParams,
|
||||
queryChatGPTForJSON,
|
||||
)
|
||||
import Wasp.AI.GenerateNewProject.Common.Prompts (appDescriptionBlock)
|
||||
import qualified Wasp.AI.GenerateNewProject.Common.Prompts as Prompts
|
||||
import Wasp.AI.OpenAI.ChatGPT (ChatGPTParams (_model), ChatMessage (..), ChatRole (..), Model (GPT_4))
|
||||
import Wasp.AI.OpenAI.ChatGPT (ChatMessage (..), ChatRole (..))
|
||||
import qualified Wasp.Psl.Format as Prisma
|
||||
import qualified Wasp.Psl.Parser.Model as Psl.Parser
|
||||
import qualified Wasp.Util.Aeson as Util.Aeson
|
||||
@ -42,7 +42,7 @@ type PlanRule = String
|
||||
generatePlan :: NewProjectDetails -> [PlanRule] -> CodeAgent Plan
|
||||
generatePlan newProjectDetails planRules = do
|
||||
writeToLog "Generating plan (slowest step, usually takes 30 to 90 seconds)..."
|
||||
initialPlan <- queryChatGPTForJSON ((defaultChatGPTParams newProjectDetails) {_model = planGptModel}) chatMessages
|
||||
initialPlan <- queryChatGPTForJSON (planningChatGPTParams newProjectDetails) chatMessages
|
||||
writeToLog $ "Initial plan generated!\n" <> summarizePlan initialPlan
|
||||
writeToLog "Fixing initial plan..."
|
||||
fixedPlan <- fixPlanRepeatedly 3 initialPlan
|
||||
@ -170,7 +170,7 @@ generatePlan newProjectDetails planRules = do
|
||||
|]
|
||||
writeToLog "Sending plan to GPT for fixing..."
|
||||
fixedPlan <-
|
||||
queryChatGPTForJSON ((defaultChatGPTParamsForFixing newProjectDetails) {_model = planGptModel}) $
|
||||
queryChatGPTForJSON (fixingChatGPTParams $ planningChatGPTParams newProjectDetails) $
|
||||
chatMessages
|
||||
<> [ ChatMessage {role = Assistant, content = Util.Aeson.encodeToText plan'},
|
||||
ChatMessage
|
||||
@ -190,8 +190,6 @@ generatePlan newProjectDetails planRules = do
|
||||
]
|
||||
return (False, fixedPlan)
|
||||
|
||||
planGptModel = GPT_4
|
||||
|
||||
checkPlanForEntityIssues :: Plan -> [String]
|
||||
checkPlanForEntityIssues plan =
|
||||
checkNumEntities
|
||||
|
@ -18,7 +18,8 @@ import NeatInterpolation (trimming)
|
||||
import Wasp.AI.CodeAgent (CodeAgent, getFile, writeToFile)
|
||||
import Wasp.AI.GenerateNewProject.Common
|
||||
( NewProjectDetails,
|
||||
defaultChatGPTParamsForFixing,
|
||||
codingChatGPTParams,
|
||||
fixingChatGPTParams,
|
||||
queryChatGPTForJSON,
|
||||
)
|
||||
import Wasp.AI.GenerateNewProject.Common.Prompts (appDescriptionBlock)
|
||||
@ -51,7 +52,7 @@ fixWaspFile newProjectDetails waspFilePath plan = do
|
||||
OnlyIfCompileErrors | null compileErrors -> return $ WaspFile {waspFileContent = wfContent}
|
||||
_otherwise ->
|
||||
queryChatGPTForJSON
|
||||
(defaultChatGPTParamsForFixing newProjectDetails)
|
||||
(fixingChatGPTParams $ codingChatGPTParams newProjectDetails)
|
||||
[ ChatMessage {role = System, content = Prompts.systemPrompt},
|
||||
ChatMessage {role = User, content = fixWaspFilePrompt wfContent compileErrors}
|
||||
]
|
||||
|
@ -88,7 +88,17 @@ data ChatGPTParams = ChatGPTParams
|
||||
deriving (Show)
|
||||
|
||||
-- TODO: There are more models available than the ones listed here, but for now we went with these core ones.
|
||||
data Model = GPT_3_5_turbo | GPT_3_5_turbo_16k | GPT_4
|
||||
data Model
|
||||
= GPT_3_5_turbo_1106
|
||||
| GPT_3_5_turbo
|
||||
| GPT_3_5_turbo_16k
|
||||
| GPT_3_5_turbo_0613
|
||||
| GPT_3_5_turbo_16k_0613
|
||||
| GPT_4_1106_Preview
|
||||
| GPT_4
|
||||
| GPT_4_32k
|
||||
| GPT_4_0613
|
||||
| GPT_4_32k_0613
|
||||
deriving (Eq, Bounded, Enum)
|
||||
|
||||
instance Show Model where
|
||||
@ -96,9 +106,16 @@ instance Show Model where
|
||||
|
||||
modelOpenAiId :: Model -> String
|
||||
modelOpenAiId = \case
|
||||
GPT_3_5_turbo_1106 -> "gpt-3.5-turbo-1106"
|
||||
GPT_3_5_turbo -> "gpt-3.5-turbo"
|
||||
GPT_3_5_turbo_16k -> "gpt-3.5-turbo-16k"
|
||||
GPT_3_5_turbo_0613 -> "gpt-3.5-turbo-0613"
|
||||
GPT_3_5_turbo_16k_0613 -> "gpt-3.5-turbo-16k-0613"
|
||||
GPT_4_1106_Preview -> "gpt-4-1106-preview"
|
||||
GPT_4 -> "gpt-4"
|
||||
GPT_4_32k -> "gpt-4-32k"
|
||||
GPT_4_0613 -> "gpt-4-0613"
|
||||
GPT_4_32k_0613 -> "gpt-4-32k-0613"
|
||||
|
||||
instance FromJSON Model where
|
||||
parseJSON = Aeson.withText "Model" $ \t ->
|
||||
|
Loading…
Reference in New Issue
Block a user