diff --git a/assets/settings/default.json b/assets/settings/default.json
index 82e848dba0..7afa7f0d3b 100644
--- a/assets/settings/default.json
+++ b/assets/settings/default.json
@@ -212,6 +212,8 @@
     "default_width": 640,
     // Default height when the assistant is docked to the bottom.
     "default_height": 320,
+    // The default OpenAI API endpoint to use when starting new conversations.
+    "openai_api_url": "https://api.openai.com/v1",
     // The default OpenAI model to use when starting new conversations. This
     // setting can take three values:
     //
diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs
index 4bdb94d79b..f3c7ebbdbc 100644
--- a/crates/ai/src/providers/open_ai/completion.rs
+++ b/crates/ai/src/providers/open_ai/completion.rs
@@ -103,6 +103,7 @@ pub struct OpenAiResponseStreamEvent {
 }
 
 pub async fn stream_completion(
+    api_url: String,
     credential: ProviderCredential,
     executor: BackgroundExecutor,
     request: Box<dyn CompletionRequest>,
@@ -117,7 +118,7 @@ pub async fn stream_completion(
     let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
 
     let json_data = request.data()?;
-    let mut response = Request::post(format!("{OPEN_AI_API_URL}/chat/completions"))
+    let mut response = Request::post(format!("{api_url}/chat/completions"))
         .header("Content-Type", "application/json")
         .header("Authorization", format!("Bearer {}", api_key))
         .body(json_data)?
@@ -195,18 +196,20 @@ pub async fn stream_completion(
 
 #[derive(Clone)]
 pub struct OpenAiCompletionProvider {
+    api_url: String,
     model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     executor: BackgroundExecutor,
 }
 
 impl OpenAiCompletionProvider {
-    pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
+    pub async fn new(api_url: String, model_name: String, executor: BackgroundExecutor) -> Self {
         let model = executor
             .spawn(async move { OpenAiLanguageModel::load(&model_name) })
             .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
         Self {
+            api_url,
             model,
             credential,
             executor,
@@ -303,7 +306,8 @@ impl CompletionProvider for OpenAiCompletionProvider {
         // which is currently model based, due to the language model.
         // At some point in the future we should rectify this.
         let credential = self.credential.read().clone();
-        let request = stream_completion(credential, self.executor.clone(), prompt);
+        let api_url = self.api_url.clone();
+        let request = stream_completion(api_url, credential, self.executor.clone(), prompt);
         async move {
             let response = request.await?;
             let stream = response
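Note on URL construction: `stream_completion` builds the endpoint by plain string formatting, so a configured base URL with a trailing slash would produce a double slash in the request path. A minimal standalone sketch of that behavior (the URL values are illustrative):

```rust
fn main() {
    // Mirrors the `format!("{api_url}/chat/completions")` call in the hunk
    // above: the path is appended verbatim, so "https://host/v1/" (with a
    // trailing slash) would yield "https://host/v1//chat/completions".
    let api_url = "https://api.openai.com/v1";
    let endpoint = format!("{api_url}/chat/completions");
    assert_eq!(endpoint, "https://api.openai.com/v1/chat/completions");
}
```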
diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs
index 29ee8fac9b..588861a972 100644
--- a/crates/ai/src/providers/open_ai/embedding.rs
+++ b/crates/ai/src/providers/open_ai/embedding.rs
@@ -35,6 +35,7 @@ lazy_static! {
 
 #[derive(Clone)]
 pub struct OpenAiEmbeddingProvider {
+    api_url: String,
     model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     pub client: Arc<dyn HttpClient>,
@@ -69,7 +70,11 @@ struct OpenAiEmbeddingUsage {
 }
 
 impl OpenAiEmbeddingProvider {
-    pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
+    pub async fn new(
+        api_url: String,
+        client: Arc<dyn HttpClient>,
+        executor: BackgroundExecutor,
+    ) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
@@ -80,6 +85,7 @@ impl OpenAiEmbeddingProvider {
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
 
         OpenAiEmbeddingProvider {
+            api_url,
             model,
             credential,
             client,
@@ -130,11 +136,12 @@ impl OpenAiEmbeddingProvider {
     }
 
     async fn send_request(
         &self,
+        api_url: &str,
         api_key: &str,
         spans: Vec<&str>,
         request_timeout: u64,
     ) -> Result<Response<AsyncBody>> {
-        let request = Request::post(format!("{OPEN_AI_API_URL}/embeddings"))
+        let request = Request::post(format!("{api_url}/embeddings"))
             .redirect_policy(isahc::config::RedirectPolicy::Follow)
             .timeout(Duration::from_secs(request_timeout))
             .header("Content-Type", "application/json")
@@ -246,6 +253,7 @@ impl EmbeddingProvider for OpenAiEmbeddingProvider {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
+        let api_url = self.api_url.as_str();
        let api_key = self.get_api_key()?;
         let mut request_number = 0;
 
@@ -255,6 +263,7 @@ impl EmbeddingProvider for OpenAiEmbeddingProvider {
         while request_number < MAX_RETRIES {
             response = self
                 .send_request(
+                    &api_url,
                     &api_key,
                     spans.iter().map(|x| &**x).collect(),
                     request_timeout,
diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs
index d86d889aff..d262ffd57d 100644
--- a/crates/assistant/src/assistant.rs
+++ b/crates/assistant/src/assistant.rs
@@ -68,6 +68,7 @@ struct SavedConversation {
     messages: Vec<SavedMessage>,
     message_metadata: HashMap<MessageId, MessageMetadata>,
     summary: String,
+    api_url: Option<String>,
     model: OpenAiModel,
 }
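Because the new `api_url` field is an `Option<String>`, conversations saved before this change (which lack the key) should still deserialize: serde's derive fills a missing `Option` field with `None`. A self-contained sketch using stand-in types, not the crate's real definitions:

```rust
use serde::Deserialize;

// Stand-in for the SavedConversation change above; only the fields needed
// to demonstrate the backward-compatibility behavior are included.
#[derive(Deserialize)]
struct SavedConversation {
    summary: String,
    api_url: Option<String>,
}

fn main() -> serde_json::Result<()> {
    // A conversation saved before this PR has no "api_url" key...
    let old: SavedConversation = serde_json::from_str(r#"{"summary": "older save"}"#)?;
    assert_eq!(old.summary, "older save");
    assert_eq!(old.api_url, None);

    // ...while a newer save round-trips the recorded endpoint.
    let new: SavedConversation = serde_json::from_str(
        r#"{"summary": "newer save", "api_url": "https://api.openai.com/v1"}"#,
    )?;
    assert_eq!(new.api_url.as_deref(), Some("https://api.openai.com/v1"));
    Ok(())
}
```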
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 3bd928961d..4e861c9d3e 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -7,6 +7,7 @@ use crate::{
     SavedMessage, Split, ToggleFocus, ToggleIncludeConversation, ToggleRetrieveContext,
 };
 use ai::prompts::repository_context::PromptCodeSnippet;
+use ai::providers::open_ai::OPEN_AI_API_URL;
 use ai::{
     auth::ProviderCredential,
     completion::{CompletionProvider, CompletionRequest},
@@ -121,10 +122,22 @@ impl AssistantPanel {
                 .await
                 .log_err()
                 .unwrap_or_default();
-            // Defaulting currently to GPT4, allow for this to be set via config.
-            let completion_provider =
-                OpenAiCompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
-                    .await;
+            let (api_url, model_name) = cx
+                .update(|cx| {
+                    let settings = AssistantSettings::get_global(cx);
+                    (
+                        settings.openai_api_url.clone(),
+                        settings.default_open_ai_model.full_name().to_string(),
+                    )
+                })
+                .log_err()
+                .unwrap();
+            let completion_provider = OpenAiCompletionProvider::new(
+                api_url,
+                model_name,
+                cx.background_executor().clone(),
+            )
+            .await;
 
             // TODO: deserialize state.
             let workspace_handle = workspace.clone();
@@ -1407,6 +1420,7 @@ struct Conversation {
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
     model: OpenAiModel,
+    api_url: Option<String>,
     token_count: Option<usize>,
     max_token_count: usize,
     pending_token_count: Task<Option<()>>,
@@ -1441,6 +1455,7 @@ impl Conversation {
         let settings = AssistantSettings::get_global(cx);
         let model = settings.default_open_ai_model.clone();
+        let api_url = settings.openai_api_url.clone();
 
         let mut this = Self {
             id: Some(Uuid::new_v4().to_string()),
@@ -1454,6 +1469,7 @@ impl Conversation {
             token_count: None,
             max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
             pending_token_count: Task::ready(None),
+            api_url: Some(api_url),
             model: model.clone(),
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
@@ -1499,6 +1515,7 @@ impl Conversation {
                 .map(|summary| summary.text.clone())
                 .unwrap_or_default(),
             model: self.model.clone(),
+            api_url: self.api_url.clone(),
         }
     }
 
@@ -1513,8 +1530,12 @@ impl Conversation {
             None => Some(Uuid::new_v4().to_string()),
         };
         let model = saved_conversation.model;
+        let api_url = saved_conversation.api_url;
         let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
             OpenAiCompletionProvider::new(
+                api_url
+                    .clone()
+                    .unwrap_or_else(|| OPEN_AI_API_URL.to_string()),
                 model.full_name().into(),
                 cx.background_executor().clone(),
             )
@@ -1567,6 +1588,7 @@ impl Conversation {
             token_count: None,
             max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
             pending_token_count: Task::ready(None),
+            api_url,
             model,
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs
index 4b37a1b2f6..16c2d452c2 100644
--- a/crates/assistant/src/assistant_settings.rs
+++ b/crates/assistant/src/assistant_settings.rs
@@ -55,6 +55,7 @@ pub struct AssistantSettings {
     pub default_width: Pixels,
     pub default_height: Pixels,
     pub default_open_ai_model: OpenAiModel,
+    pub openai_api_url: String,
 }
 
 /// Assistant panel settings
@@ -80,6 +81,10 @@ pub struct AssistantSettingsContent {
     ///
     /// Default: gpt-4-1106-preview
     pub default_open_ai_model: Option<OpenAiModel>,
+    /// OpenAI API base URL to use when starting new conversations.
+    ///
+    /// Default: https://api.openai.com/v1
+    pub openai_api_url: Option<String>,
 }
 
 impl Settings for AssistantSettings {
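Two notes on the wiring above: new conversations pick the URL up from `AssistantSettings` (users would override it in their settings file under the same section as `default_width` in the default.json hunk, presumably `"assistant"`), while `Conversation::deserialize` falls back to the stock endpoint for saves that predate the field. A standalone sketch of that fallback; the helper name is hypothetical:

```rust
// Mirrors `api_url.clone().unwrap_or_else(|| OPEN_AI_API_URL.to_string())`
// in Conversation::deserialize; the constant matches the one in crates/ai.
const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";

// Hypothetical helper, for illustration only.
fn effective_api_url(saved: Option<String>) -> String {
    saved.unwrap_or_else(|| OPEN_AI_API_URL.to_string())
}

fn main() {
    // Old saves (no recorded api_url) keep talking to the stock endpoint...
    assert_eq!(effective_api_url(None), OPEN_AI_API_URL);
    // ...while saves that recorded a custom endpoint keep using it.
    assert_eq!(
        effective_api_url(Some("http://localhost:1234/v1".into())),
        "http://localhost:1234/v1"
    );
}
```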
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 324e03381e..df277fbc9b 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -8,7 +8,7 @@ mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
 use ai::embedding::{Embedding, EmbeddingProvider};
-use ai::providers::open_ai::OpenAiEmbeddingProvider;
+use ai::providers::open_ai::{OpenAiEmbeddingProvider, OPEN_AI_API_URL};
 use anyhow::{anyhow, Context as _, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -91,8 +91,13 @@ pub fn init(
     .detach();
 
     cx.spawn(move |cx| async move {
-        let embedding_provider =
-            OpenAiEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
+        let embedding_provider = OpenAiEmbeddingProvider::new(
+            // TODO: We should read it from config, but I'm not sure whether to reuse `openai_api_url` in assistant settings or not
+            OPEN_AI_API_URL.to_string(),
+            http_client,
+            cx.background_executor().clone(),
+        )
+        .await;
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,
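If the TODO above were resolved by reusing the assistant setting, the call could mirror how assistant_panel.rs reads it earlier in this PR. A hedged sketch only, not part of the change: it presupposes that sharing `openai_api_url` with the assistant is acceptable, and it would require semantic_index to depend on the assistant settings, which may be exactly why the question is left open.

```rust
// Sketch: read the assistant's configured endpoint, falling back to the
// stock URL if the settings read fails. Follows the cx.update +
// AssistantSettings::get_global + log_err pattern used in assistant_panel.rs.
let api_url = cx
    .update(|cx| AssistantSettings::get_global(cx).openai_api_url.clone())
    .log_err()
    .unwrap_or_else(|| OPEN_AI_API_URL.to_string());
let embedding_provider = OpenAiEmbeddingProvider::new(
    api_url,
    http_client,
    cx.background_executor().clone(),
)
.await;
```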