diff --git a/api/system.go b/api/system.go index 18c397ce..927d742b 100644 --- a/api/system.go +++ b/api/system.go @@ -19,4 +19,6 @@ type SystemStatus struct { // Customized server profile, including server name and external url. CustomizedProfile CustomizedProfile `json:"customizedProfile"` StorageServiceID int `json:"storageServiceId"` + // OpenAI API Host + OpenAIAPIHost string `json:"openAIApiHost"` } diff --git a/api/system_setting.go b/api/system_setting.go index 5069fd55..2b271993 100644 --- a/api/system_setting.go +++ b/api/system_setting.go @@ -29,6 +29,8 @@ const ( SystemSettingStorageServiceIDName SystemSettingName = "storageServiceId" // SystemSettingOpenAIAPIKeyName is the key type of OpenAI API key. SystemSettingOpenAIAPIKeyName SystemSettingName = "openAIApiKey" + // SystemSettingOpenAIAPIHost is the key type of OpenAI API path. + SystemSettingOpenAIAPIHost SystemSettingName = "openAIApiHost" ) // CustomizedProfile is the struct definition for SystemSettingCustomizedProfileName system setting item. @@ -67,6 +69,8 @@ func (key SystemSettingName) String() string { return "storageServiceId" case SystemSettingOpenAIAPIKeyName: return "openAIApiKey" + case SystemSettingOpenAIAPIHost: + return "openAIApiHost" } return "" } @@ -171,6 +175,12 @@ func (upsert SystemSettingUpsert) Validate() error { if err != nil { return fmt.Errorf("failed to unmarshal system setting openai api key value") } + } else if upsert.Name == SystemSettingOpenAIAPIHost { + value := "" + err := json.Unmarshal([]byte(upsert.Value), &value) + if err != nil { + return fmt.Errorf("failed to unmarshal system setting openai api host value") + } } else { return fmt.Errorf("invalid system setting name") } diff --git a/plugin/openai/chat_completion.go b/plugin/openai/chat_completion.go index 3b5657e0..b83147a9 100644 --- a/plugin/openai/chat_completion.go +++ b/plugin/openai/chat_completion.go @@ -5,6 +5,7 @@ import ( "errors" "io" "net/http" + "net/url" "strings" ) @@ -23,12 +24,20 @@ type ChatCompletionResponse struct { Choices []ChatCompletionChoice `json:"choices"` } -func PostChatCompletion(prompt string, apiKey string) (string, error) { +func PostChatCompletion(prompt string, apiKey string, apiHost string) (string, error) { requestBody := strings.NewReader(`{ "model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "` + prompt + `"}] }`) - req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", requestBody) + if apiHost == "" { + apiHost = "https://api.openai.com" + } + url, err := url.JoinPath(apiHost, "/v1/chat/completions") + if err != nil { + return "", err + } + + req, err := http.NewRequest("POST", url, requestBody) if err != nil { return "", err } diff --git a/plugin/openai/text_completion.go b/plugin/openai/text_completion.go index 26cf2b64..25ce7dbf 100644 --- a/plugin/openai/text_completion.go +++ b/plugin/openai/text_completion.go @@ -5,6 +5,7 @@ import ( "errors" "io" "net/http" + "net/url" "strings" ) @@ -18,7 +19,7 @@ type TextCompletionResponse struct { Choices []TextCompletionChoice `json:"choices"` } -func PostTextCompletion(prompt string, apiKey string) (string, error) { +func PostTextCompletion(prompt string, apiKey string, apiHost string) (string, error) { requestBody := strings.NewReader(`{ "prompt": "` + prompt + `", "temperature": 0.5, @@ -26,7 +27,15 @@ func PostTextCompletion(prompt string, apiKey string) (string, error) { "n": 1, "stop": "." }`) - req, err := http.NewRequest("POST", "https://api.openai.com/v1/completions", requestBody) + if apiHost == "" { + apiHost = "https://api.openai.com" + } + url, err := url.JoinPath(apiHost, "/v1/chat/completions") + if err != nil { + return "", err + } + + req, err := http.NewRequest("POST", url, requestBody) if err != nil { return "", err } diff --git a/server/openai.go b/server/openai.go index 0761cc34..7be7129b 100644 --- a/server/openai.go +++ b/server/openai.go @@ -20,6 +20,13 @@ func (s *Server) registerOpenAIRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api key").SetInternal(err) } + openAIApiHostSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ + Name: api.SystemSettingOpenAIAPIHost, + }) + if err != nil && common.ErrorCode(err) != common.NotFound { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api host").SetInternal(err) + } + openAIApiKey := "" if openAIApiKeySetting != nil { err = json.Unmarshal([]byte(openAIApiKeySetting.Value), &openAIApiKey) @@ -31,6 +38,14 @@ func (s *Server) registerOpenAIRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set") } + openAIApiHost := "" + if openAIApiHostSetting != nil { + err = json.Unmarshal([]byte(openAIApiHostSetting.Value), &openAIApiHost) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting value").SetInternal(err) + } + } + completionRequest := api.OpenAICompletionRequest{} if err := json.NewDecoder(c.Request().Body).Decode(&completionRequest); err != nil { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post chat completion request").SetInternal(err) @@ -39,7 +54,7 @@ func (s *Server) registerOpenAIRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "Prompt is required") } - result, err := openai.PostChatCompletion(completionRequest.Prompt, openAIApiKey) + result, err := openai.PostChatCompletion(completionRequest.Prompt, openAIApiKey, openAIApiHost) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to post chat completion").SetInternal(err) } @@ -56,6 +71,13 @@ func (s *Server) registerOpenAIRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api key").SetInternal(err) } + openAIApiHostSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ + Name: api.SystemSettingOpenAIAPIHost, + }) + if err != nil && common.ErrorCode(err) != common.NotFound { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai api host").SetInternal(err) + } + openAIApiKey := "" if openAIApiKeySetting != nil { err = json.Unmarshal([]byte(openAIApiKeySetting.Value), &openAIApiKey) @@ -67,6 +89,14 @@ func (s *Server) registerOpenAIRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set") } + openAIApiHost := "" + if openAIApiHostSetting != nil { + err = json.Unmarshal([]byte(openAIApiHostSetting.Value), &openAIApiHost) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting value").SetInternal(err) + } + } + textCompletion := api.OpenAICompletionRequest{} if err := json.NewDecoder(c.Request().Body).Decode(&textCompletion); err != nil { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post text completion request").SetInternal(err) @@ -75,7 +105,7 @@ func (s *Server) registerOpenAIRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "Prompt is required") } - result, err := openai.PostTextCompletion(textCompletion.Prompt, openAIApiKey) + result, err := openai.PostTextCompletion(textCompletion.Prompt, openAIApiKey, openAIApiHost) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to post text completion").SetInternal(err) } diff --git a/server/system.go b/server/system.go index a2b2a717..f87a1557 100644 --- a/server/system.go +++ b/server/system.go @@ -52,6 +52,7 @@ func (s *Server) registerSystemRoutes(g *echo.Group) { ExternalURL: "", }, StorageServiceID: 0, + OpenAIAPIHost: "", } systemSettingList, err := s.Store.FindSystemSettingList(ctx, &api.SystemSettingFind{}) @@ -100,6 +101,8 @@ func (s *Server) registerSystemRoutes(g *echo.Group) { } } else if systemSetting.Name == api.SystemSettingStorageServiceIDName { systemStatus.StorageServiceID = int(value.(float64)) + } else if systemSetting.Name == api.SystemSettingOpenAIAPIHost { + systemStatus.OpenAIAPIHost = value.(string) } } diff --git a/web/src/components/Settings/SystemSection.tsx b/web/src/components/Settings/SystemSection.tsx index 25479755..599f97dd 100644 --- a/web/src/components/Settings/SystemSection.tsx +++ b/web/src/components/Settings/SystemSection.tsx @@ -14,6 +14,7 @@ interface State { allowSignUp: boolean; disablePublicMemos: boolean; openAIApiKey: string; + openAIApiHost: string; additionalStyle: string; additionalScript: string; } @@ -36,6 +37,7 @@ const SystemSection = () => { allowSignUp: systemStatus.allowSignUp, additionalStyle: systemStatus.additionalStyle, openAIApiKey: "", + openAIApiHost: systemStatus.openAIApiHost, additionalScript: systemStatus.additionalScript, disablePublicMemos: systemStatus.disablePublicMemos, }); @@ -52,6 +54,7 @@ const SystemSection = () => { allowSignUp: systemStatus.allowSignUp, additionalStyle: systemStatus.additionalStyle, openAIApiKey: "", + openAIApiHost: systemStatus.openAIApiHost, additionalScript: systemStatus.additionalScript, disablePublicMemos: systemStatus.disablePublicMemos, }); @@ -103,6 +106,26 @@ const SystemSection = () => { toastHelper.success("OpenAI Api Key updated"); }; + const handleOpenAIApiHostChanged = (value: string) => { + setState({ + ...state, + openAIApiHost: value, + }); + }; + + const handleSaveOpenAIApiHost = async () => { + try { + await api.upsertSystemSetting({ + name: "openAIApiHost", + value: JSON.stringify(state.openAIApiHost), + }); + } catch (error) { + console.error(error); + return; + } + toastHelper.success("OpenAI Api Host updated"); + }; + const handleAdditionalStyleChanged = (value: string) => { setState({ ...state, @@ -195,6 +218,20 @@ const SystemSection = () => { value={state.openAIApiKey} onChange={(event) => handleOpenAIApiKeyChanged(event.target.value)} /> +
+ OpenAI API Host + +
+ handleOpenAIApiHostChanged(event.target.value)} + />
{t("setting.system-section.additional-style")} diff --git a/web/src/store/module/global.ts b/web/src/store/module/global.ts index b83fe0a1..62aa965c 100644 --- a/web/src/store/module/global.ts +++ b/web/src/store/module/global.ts @@ -22,6 +22,7 @@ export const initialGlobalState = async () => { appearance: "system", externalUrl: "", }, + openAIApiHost: "", } as SystemStatus, }; diff --git a/web/src/types/modules/system.d.ts b/web/src/types/modules/system.d.ts index 18e99d8f..bc7b602b 100644 --- a/web/src/types/modules/system.d.ts +++ b/web/src/types/modules/system.d.ts @@ -23,6 +23,7 @@ interface SystemStatus { additionalScript: string; customizedProfile: CustomizedProfile; storageServiceId: number; + openAIApiHost: string; } interface SystemSetting {