Allow OpenAI API URL to be configured via assistant.openai_api_url (#7552)

Partially fixes #4321, since the Azure OpenAI API can be adapted to the OpenAI
API.

Release Notes:

- Added `assistant.openai_api_url` setting to allow the OpenAI API URL to be
configured.

---------

Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
Author: Yesterday17
Date: 2024-02-13 00:37:27 +08:00 (committed by GitHub)
Parent: d959719f3e
Commit: 9e17018416
7 changed files with 60 additions and 12 deletions


@@ -212,6 +212,8 @@
     "default_width": 640,
     // Default height when the assistant is docked to the bottom.
     "default_height": 320,
+    // The default OpenAI API endpoint to use when starting new conversations.
+    "openai_api_url": "https://api.openai.com/v1",
     // The default OpenAI model to use when starting new conversations. This
     // setting can take three values:
     //


@@ -103,6 +103,7 @@ pub struct OpenAiResponseStreamEvent {
 }
 
 pub async fn stream_completion(
+    api_url: String,
     credential: ProviderCredential,
     executor: BackgroundExecutor,
     request: Box<dyn CompletionRequest>,
@@ -117,7 +118,7 @@ pub async fn stream_completion(
     let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
 
     let json_data = request.data()?;
-    let mut response = Request::post(format!("{OPEN_AI_API_URL}/chat/completions"))
+    let mut response = Request::post(format!("{api_url}/chat/completions"))
         .header("Content-Type", "application/json")
         .header("Authorization", format!("Bearer {}", api_key))
         .body(json_data)?
@@ -195,18 +196,20 @@ pub async fn stream_completion(
 
 #[derive(Clone)]
 pub struct OpenAiCompletionProvider {
+    api_url: String,
     model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     executor: BackgroundExecutor,
 }
 
 impl OpenAiCompletionProvider {
-    pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
+    pub async fn new(api_url: String, model_name: String, executor: BackgroundExecutor) -> Self {
         let model = executor
             .spawn(async move { OpenAiLanguageModel::load(&model_name) })
             .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
         Self {
+            api_url,
             model,
             credential,
             executor,
@@ -303,7 +306,8 @@ impl CompletionProvider for OpenAiCompletionProvider {
         // which is currently model based, due to the language model.
         // At some point in the future we should rectify this.
         let credential = self.credential.read().clone();
-        let request = stream_completion(credential, self.executor.clone(), prompt);
+        let api_url = self.api_url.clone();
+        let request = stream_completion(api_url, credential, self.executor.clone(), prompt);
         async move {
             let response = request.await?;
             let stream = response
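
The only functional change to the request itself is which base URL gets interpolated into the endpoint path. A small standalone sketch of that joining behavior (the proxy URL is illustrative, not something this commit introduces); since the path is joined with a plain `format!`, a value without a trailing slash, like the default, is what the code expects:

```rust
// Sketch of how the base URL is joined with the endpoint path after this
// change; mirrors format!("{api_url}/chat/completions") in stream_completion.
fn chat_completions_endpoint(api_url: &str) -> String {
    format!("{api_url}/chat/completions")
}

fn main() {
    // The default value of the new setting.
    let default_url = "https://api.openai.com/v1";
    // A hypothetical OpenAI-compatible proxy (e.g. in front of Azure OpenAI).
    let proxy_url = "https://my-openai-proxy.example.com/v1";

    assert_eq!(
        chat_completions_endpoint(default_url),
        "https://api.openai.com/v1/chat/completions"
    );
    assert_eq!(
        chat_completions_endpoint(proxy_url),
        "https://my-openai-proxy.example.com/v1/chat/completions"
    );
}
```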


@@ -35,6 +35,7 @@ lazy_static! {
 
 #[derive(Clone)]
 pub struct OpenAiEmbeddingProvider {
+    api_url: String,
     model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     pub client: Arc<dyn HttpClient>,
@@ -69,7 +70,11 @@ struct OpenAiEmbeddingUsage {
 }
 
 impl OpenAiEmbeddingProvider {
-    pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
+    pub async fn new(
+        api_url: String,
+        client: Arc<dyn HttpClient>,
+        executor: BackgroundExecutor,
+    ) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
@@ -80,6 +85,7 @@ impl OpenAiEmbeddingProvider {
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
 
         OpenAiEmbeddingProvider {
+            api_url,
             model,
             credential,
             client,
@@ -130,11 +136,12 @@ impl OpenAiEmbeddingProvider {
     }
     async fn send_request(
         &self,
+        api_url: &str,
         api_key: &str,
         spans: Vec<&str>,
         request_timeout: u64,
     ) -> Result<Response<AsyncBody>> {
-        let request = Request::post(format!("{OPEN_AI_API_URL}/embeddings"))
+        let request = Request::post(format!("{api_url}/embeddings"))
             .redirect_policy(isahc::config::RedirectPolicy::Follow)
             .timeout(Duration::from_secs(request_timeout))
             .header("Content-Type", "application/json")
@@ -246,6 +253,7 @@ impl EmbeddingProvider for OpenAiEmbeddingProvider {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
+        let api_url = self.api_url.as_str();
         let api_key = self.get_api_key()?;
         let mut request_number = 0;
@@ -255,6 +263,7 @@ impl EmbeddingProvider for OpenAiEmbeddingProvider {
         while request_number < MAX_RETRIES {
             response = self
                 .send_request(
+                    &api_url,
                    &api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
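
The retry loop around `send_request` is untouched apart from threading the URL through. For reference, a standalone sketch of the backoff schedule visible in this hunk; the `send` function below is a made-up stand-in, not the provider's actual `send_request`:

```rust
// Standalone sketch of the retry schedule shown above: up to MAX_RETRIES
// attempts, sleeping BACKOFF_SECONDS[n] seconds between failed attempts.
use std::{thread, time::Duration};

const BACKOFF_SECONDS: [u64; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;

// Stand-in for a request: pretend the first two attempts are rate limited.
fn send(attempt: usize) -> Result<&'static str, &'static str> {
    if attempt < 2 {
        Err("rate limited")
    } else {
        Ok("embedding response")
    }
}

fn main() {
    let mut request_number = 0;
    while request_number < MAX_RETRIES {
        match send(request_number) {
            Ok(body) => {
                println!("attempt {} succeeded: {body}", request_number + 1);
                break;
            }
            Err(err) => {
                println!(
                    "attempt {} failed ({err}); backing off for {}s",
                    request_number + 1,
                    BACKOFF_SECONDS[request_number]
                );
                thread::sleep(Duration::from_secs(BACKOFF_SECONDS[request_number]));
                request_number += 1;
            }
        }
    }
}
```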


@@ -68,6 +68,7 @@ struct SavedConversation {
     messages: Vec<SavedMessage>,
     message_metadata: HashMap<MessageId, MessageMetadata>,
     summary: String,
+    api_url: Option<String>,
     model: OpenAiModel,
 }
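
Making the field `Option<String>` keeps previously saved conversations representable. A standalone sketch of the idea, assuming saved conversations are (de)serialized with serde and that `serde`/`serde_json` are available as dependencies; `SavedConversationSketch` is a stand-in, not the real struct:

```rust
// Sketch only: a stand-in struct showing why `api_url` is Option<String>.
// With #[serde(default)], records written before this change (no `api_url`
// key) deserialize to None, while new records round-trip the URL.
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug)]
struct SavedConversationSketch {
    summary: String,
    #[serde(default)]
    api_url: Option<String>,
}

fn main() -> serde_json::Result<()> {
    // An old record without the field still parses; api_url becomes None.
    let old: SavedConversationSketch =
        serde_json::from_str(r#"{ "summary": "pre-existing conversation" }"#)?;
    assert_eq!(old.api_url, None);

    // A new record carries the URL it was created with.
    let new: SavedConversationSketch = serde_json::from_str(
        r#"{ "summary": "new conversation", "api_url": "https://api.openai.com/v1" }"#,
    )?;
    assert_eq!(new.api_url.as_deref(), Some("https://api.openai.com/v1"));
    Ok(())
}
```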


@@ -7,6 +7,7 @@ use crate::{
     SavedMessage, Split, ToggleFocus, ToggleIncludeConversation, ToggleRetrieveContext,
 };
 use ai::prompts::repository_context::PromptCodeSnippet;
+use ai::providers::open_ai::OPEN_AI_API_URL;
 use ai::{
     auth::ProviderCredential,
     completion::{CompletionProvider, CompletionRequest},
@@ -121,9 +122,21 @@ impl AssistantPanel {
             .await
             .log_err()
             .unwrap_or_default();
-        // Defaulting currently to GPT4, allow for this to be set via config.
-        let completion_provider =
-            OpenAiCompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
-                .await;
+        let (api_url, model_name) = cx
+            .update(|cx| {
+                let settings = AssistantSettings::get_global(cx);
+                (
+                    settings.openai_api_url.clone(),
+                    settings.default_open_ai_model.full_name().to_string(),
+                )
+            })
+            .log_err()
+            .unwrap();
+        let completion_provider = OpenAiCompletionProvider::new(
+            api_url,
+            model_name,
+            cx.background_executor().clone(),
+        )
+        .await;
 
         // TODO: deserialize state.
@@ -1407,6 +1420,7 @@ struct Conversation {
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
     model: OpenAiModel,
+    api_url: Option<String>,
     token_count: Option<usize>,
     max_token_count: usize,
     pending_token_count: Task<Option<()>>,
@@ -1441,6 +1455,7 @@ impl Conversation {
         let settings = AssistantSettings::get_global(cx);
         let model = settings.default_open_ai_model.clone();
+        let api_url = settings.openai_api_url.clone();
 
         let mut this = Self {
             id: Some(Uuid::new_v4().to_string()),
@@ -1454,6 +1469,7 @@ impl Conversation {
             token_count: None,
             max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
             pending_token_count: Task::ready(None),
+            api_url: Some(api_url),
             model: model.clone(),
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
@@ -1499,6 +1515,7 @@ impl Conversation {
                 .map(|summary| summary.text.clone())
                 .unwrap_or_default(),
             model: self.model.clone(),
+            api_url: self.api_url.clone(),
         }
     }
@@ -1513,8 +1530,12 @@ impl Conversation {
             None => Some(Uuid::new_v4().to_string()),
         };
         let model = saved_conversation.model;
+        let api_url = saved_conversation.api_url;
         let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
             OpenAiCompletionProvider::new(
+                api_url
+                    .clone()
+                    .unwrap_or_else(|| OPEN_AI_API_URL.to_string()),
                 model.full_name().into(),
                 cx.background_executor().clone(),
             )
@@ -1567,6 +1588,7 @@ impl Conversation {
             token_count: None,
             max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
             pending_token_count: Task::ready(None),
+            api_url,
             model,
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
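
When a saved conversation carries no URL, the provider falls back to the `OPEN_AI_API_URL` constant, as in the `unwrap_or_else` above. A minimal standalone sketch of that fallback; the constant is modeled locally with the documented default value rather than imported from `ai::providers::open_ai`:

```rust
// Sketch of the fallback used when loading a saved conversation: no stored
// URL means the default OpenAI endpoint is used. This const stands in for
// ai::providers::open_ai::OPEN_AI_API_URL.
const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";

fn resolve_api_url(saved: Option<String>) -> String {
    saved.unwrap_or_else(|| OPEN_AI_API_URL.to_string())
}

fn main() {
    // Conversation saved before this change: fall back to the default.
    assert_eq!(resolve_api_url(None), "https://api.openai.com/v1");

    // Conversation saved against a custom, OpenAI-compatible endpoint
    // (illustrative URL).
    assert_eq!(
        resolve_api_url(Some("https://my-openai-proxy.example.com/v1".into())),
        "https://my-openai-proxy.example.com/v1"
    );
}
```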


@@ -55,6 +55,7 @@ pub struct AssistantSettings {
     pub default_width: Pixels,
     pub default_height: Pixels,
     pub default_open_ai_model: OpenAiModel,
+    pub openai_api_url: String,
 }
 
 /// Assistant panel settings
@@ -80,6 +81,10 @@ pub struct AssistantSettingsContent {
     ///
     /// Default: gpt-4-1106-preview
     pub default_open_ai_model: Option<OpenAiModel>,
+    /// OpenAI API base URL to use when starting new conversations.
+    ///
+    /// Default: https://api.openai.com/v1
+    pub openai_api_url: Option<String>,
 }
 
 impl Settings for AssistantSettings {


@@ -8,7 +8,7 @@ mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
 use ai::embedding::{Embedding, EmbeddingProvider};
-use ai::providers::open_ai::OpenAiEmbeddingProvider;
+use ai::providers::open_ai::{OpenAiEmbeddingProvider, OPEN_AI_API_URL};
 use anyhow::{anyhow, Context as _, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -91,8 +91,13 @@ pub fn init(
         .detach();
 
     cx.spawn(move |cx| async move {
-        let embedding_provider =
-            OpenAiEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
+        let embedding_provider = OpenAiEmbeddingProvider::new(
+            // TODO: We should read it from config, but I'm not sure whether to reuse `openai_api_url` in assistant settings or not
+            OPEN_AI_API_URL.to_string(),
+            http_client,
+            cx.background_executor().clone(),
+        )
+        .await;
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,