Mirror of https://github.com/zed-industries/zed.git
Allow OpenAI API URL to be configured via `assistant.openai_api_url` (#7552)

Partially fixes #4321, since the Azure OpenAI API can be adapted to the OpenAI API.

Release Notes:

- Added an `assistant.openai_api_url` setting to allow the OpenAI API URL to be configured.

Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
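For instance, a user pointing the assistant at an OpenAI-compatible endpoint (an Azure OpenAI gateway or a self-hosted proxy) could add something like the following to their settings; the URL shown is a placeholder:

{
  "assistant": {
    "openai_api_url": "https://my-openai-proxy.example.com/v1"
  }
}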
This commit is contained in: commit 9e17018416 (parent d959719f3e)
@@ -212,6 +212,8 @@
     "default_width": 640,
     // Default height when the assistant is docked to the bottom.
     "default_height": 320,
+    // The default OpenAI API endpoint to use when starting new conversations.
+    "openai_api_url": "https://api.openai.com/v1",
     // The default OpenAI model to use when starting new conversations. This
     // setting can take three values:
     //
@@ -103,6 +103,7 @@ pub struct OpenAiResponseStreamEvent {
 }
 
 pub async fn stream_completion(
+    api_url: String,
     credential: ProviderCredential,
     executor: BackgroundExecutor,
     request: Box<dyn CompletionRequest>,
@@ -117,7 +118,7 @@ pub async fn stream_completion(
     let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
 
     let json_data = request.data()?;
-    let mut response = Request::post(format!("{OPEN_AI_API_URL}/chat/completions"))
+    let mut response = Request::post(format!("{api_url}/chat/completions"))
         .header("Content-Type", "application/json")
         .header("Authorization", format!("Bearer {}", api_key))
         .body(json_data)?
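Since the request URL is now built with format!("{api_url}/chat/completions"), the configured base URL is used verbatim and should not end in a trailing slash. A minimal standalone sketch of the composition (plain Rust, no external crates; the constant mirrors the default in this diff):

const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";

fn chat_completions_url(api_url: &str) -> String {
    // The base URL is used verbatim; a trailing slash in the configured
    // value would produce a double slash in the request path.
    format!("{api_url}/chat/completions")
}

fn main() {
    assert_eq!(
        chat_completions_url(OPEN_AI_API_URL),
        "https://api.openai.com/v1/chat/completions"
    );
    // A custom endpoint (placeholder) composes the same way.
    assert_eq!(
        chat_completions_url("https://my-openai-proxy.example.com/v1"),
        "https://my-openai-proxy.example.com/v1/chat/completions"
    );
}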
@@ -195,18 +196,20 @@ pub async fn stream_completion(
 
 #[derive(Clone)]
 pub struct OpenAiCompletionProvider {
+    api_url: String,
     model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     executor: BackgroundExecutor,
 }
 
 impl OpenAiCompletionProvider {
-    pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
+    pub async fn new(api_url: String, model_name: String, executor: BackgroundExecutor) -> Self {
         let model = executor
             .spawn(async move { OpenAiLanguageModel::load(&model_name) })
             .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
         Self {
+            api_url,
             model,
             credential,
             executor,
@@ -303,7 +306,8 @@ impl CompletionProvider for OpenAiCompletionProvider {
         // which is currently model based, due to the language model.
         // At some point in the future we should rectify this.
         let credential = self.credential.read().clone();
-        let request = stream_completion(credential, self.executor.clone(), prompt);
+        let api_url = self.api_url.clone();
+        let request = stream_completion(api_url, credential, self.executor.clone(), prompt);
         async move {
             let response = request.await?;
             let stream = response
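The change to OpenAiCompletionProvider is plain constructor injection: the base URL is stored once at construction and cloned for each request, just like the credential. A simplified synchronous sketch with the GPUI executor and credential machinery stubbed out (Provider and complete are hypothetical stand-ins, not the real types):

#[derive(Clone)]
struct Provider {
    api_url: String,
    model_name: String,
}

impl Provider {
    fn new(api_url: String, model_name: String) -> Self {
        Self { api_url, model_name }
    }

    fn complete(&self, prompt: &str) -> String {
        // Mirrors `let api_url = self.api_url.clone();` in the diff:
        // each request gets its own copy of the stored base URL.
        let api_url = self.api_url.clone();
        format!("POST {api_url}/chat/completions model={} prompt={prompt}", self.model_name)
    }
}

fn main() {
    let provider = Provider::new("https://api.openai.com/v1".into(), "gpt-4".into());
    println!("{}", provider.complete("Hello"));
}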
@@ -35,6 +35,7 @@ lazy_static! {
 
 #[derive(Clone)]
 pub struct OpenAiEmbeddingProvider {
+    api_url: String,
     model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     pub client: Arc<dyn HttpClient>,
@@ -69,7 +70,11 @@ struct OpenAiEmbeddingUsage {
 }
 
 impl OpenAiEmbeddingProvider {
-    pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
+    pub async fn new(
+        api_url: String,
+        client: Arc<dyn HttpClient>,
+        executor: BackgroundExecutor,
+    ) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
@@ -80,6 +85,7 @@ impl OpenAiEmbeddingProvider {
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
 
         OpenAiEmbeddingProvider {
+            api_url,
             model,
             credential,
             client,
@@ -130,11 +136,12 @@ impl OpenAiEmbeddingProvider {
     }
     async fn send_request(
        &self,
+        api_url: &str,
        api_key: &str,
        spans: Vec<&str>,
        request_timeout: u64,
    ) -> Result<Response<AsyncBody>> {
-        let request = Request::post(format!("{OPEN_AI_API_URL}/embeddings"))
+        let request = Request::post(format!("{api_url}/embeddings"))
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .timeout(Duration::from_secs(request_timeout))
            .header("Content-Type", "application/json")
@@ -246,6 +253,7 @@ impl EmbeddingProvider for OpenAiEmbeddingProvider {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;
 
+        let api_url = self.api_url.as_str();
        let api_key = self.get_api_key()?;
 
        let mut request_number = 0;
@@ -255,6 +263,7 @@ impl EmbeddingProvider for OpenAiEmbeddingProvider {
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
+                    &api_url,
                    &api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
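For context, the retry loop that now receives the URL backs off according to BACKOFF_SECONDS and gives up after MAX_RETRIES attempts; the constants themselves are unchanged by this commit. A simplified synchronous sketch of that pattern, with the HTTP call replaced by a stub that fails twice before succeeding:

use std::{thread, time::Duration};

const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;

// Stand-in for `send_request`; the real one POSTs to `{api_url}/embeddings`.
fn send_request(api_url: &str, attempt: usize) -> Result<String, String> {
    if attempt < 2 {
        Err(format!("rate limited by {api_url}"))
    } else {
        Ok("embedding payload".to_string())
    }
}

fn main() {
    let api_url = "https://api.openai.com/v1";
    let mut request_number = 0;
    while request_number < MAX_RETRIES {
        match send_request(api_url, request_number) {
            Ok(body) => {
                println!("succeeded on attempt {request_number}: {body}");
                break;
            }
            Err(err) => {
                println!("attempt {request_number} failed: {err}");
                // Back off before the next attempt, as the real loop does.
                thread::sleep(Duration::from_secs(BACKOFF_SECONDS[request_number] as u64));
            }
        }
        request_number += 1;
    }
}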
@@ -68,6 +68,7 @@ struct SavedConversation {
     messages: Vec<SavedMessage>,
     message_metadata: HashMap<MessageId, MessageMetadata>,
     summary: String,
+    api_url: Option<String>,
     model: OpenAiModel,
 }
 
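Persisting the URL as Option<String> keeps previously saved conversations loadable, since a missing api_url field simply deserializes to None. A minimal sketch of that behavior (serde and serde_json assumed as dependencies; the struct is trimmed to the relevant fields and is not the real definition):

use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct SavedConversation {
    summary: String,
    // Absent in files saved before this commit; serde leaves it as None.
    api_url: Option<String>,
}

fn main() {
    // A pre-existing file without the field still parses.
    let old: SavedConversation =
        serde_json::from_str(r#"{ "summary": "old chat" }"#).unwrap();
    println!("{old:?}");
    assert_eq!(old.api_url, None);

    // Newly saved conversations carry the URL.
    let new: SavedConversation = serde_json::from_str(
        r#"{ "summary": "new chat", "api_url": "https://api.openai.com/v1" }"#,
    )
    .unwrap();
    assert_eq!(new.api_url.as_deref(), Some("https://api.openai.com/v1"));
}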
@@ -7,6 +7,7 @@ use crate::{
     SavedMessage, Split, ToggleFocus, ToggleIncludeConversation, ToggleRetrieveContext,
 };
 use ai::prompts::repository_context::PromptCodeSnippet;
+use ai::providers::open_ai::OPEN_AI_API_URL;
 use ai::{
     auth::ProviderCredential,
     completion::{CompletionProvider, CompletionRequest},
@@ -121,9 +122,21 @@ impl AssistantPanel {
             .await
             .log_err()
             .unwrap_or_default();
-        // Defaulting currently to GPT4, allow for this to be set via config.
-        let completion_provider =
-            OpenAiCompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
+        let (api_url, model_name) = cx
+            .update(|cx| {
+                let settings = AssistantSettings::get_global(cx);
+                (
+                    settings.openai_api_url.clone(),
+                    settings.default_open_ai_model.full_name().to_string(),
+                )
+            })
+            .log_err()
+            .unwrap();
+        let completion_provider = OpenAiCompletionProvider::new(
+            api_url,
+            model_name,
+            cx.background_executor().clone(),
+        )
         .await;
 
         // TODO: deserialize state.
@@ -1407,6 +1420,7 @@ struct Conversation {
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
     model: OpenAiModel,
+    api_url: Option<String>,
     token_count: Option<usize>,
     max_token_count: usize,
     pending_token_count: Task<Option<()>>,
@@ -1441,6 +1455,7 @@ impl Conversation {
 
         let settings = AssistantSettings::get_global(cx);
         let model = settings.default_open_ai_model.clone();
+        let api_url = settings.openai_api_url.clone();
 
         let mut this = Self {
             id: Some(Uuid::new_v4().to_string()),
@@ -1454,6 +1469,7 @@ impl Conversation {
             token_count: None,
             max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
             pending_token_count: Task::ready(None),
+            api_url: Some(api_url),
             model: model.clone(),
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
@@ -1499,6 +1515,7 @@ impl Conversation {
                 .map(|summary| summary.text.clone())
                 .unwrap_or_default(),
             model: self.model.clone(),
+            api_url: self.api_url.clone(),
         }
     }
 
@@ -1513,8 +1530,12 @@ impl Conversation {
             None => Some(Uuid::new_v4().to_string()),
         };
         let model = saved_conversation.model;
+        let api_url = saved_conversation.api_url;
         let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
             OpenAiCompletionProvider::new(
+                api_url
+                    .clone()
+                    .unwrap_or_else(|| OPEN_AI_API_URL.to_string()),
                 model.full_name().into(),
                 cx.background_executor().clone(),
             )
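The fallback when loading older conversations is just Option::unwrap_or_else over the shared default constant. In isolation:

const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";

fn resolve_api_url(saved: Option<String>) -> String {
    // Conversations saved before this commit carry no URL; use the default.
    saved.unwrap_or_else(|| OPEN_AI_API_URL.to_string())
}

fn main() {
    assert_eq!(resolve_api_url(None), OPEN_AI_API_URL);
    assert_eq!(
        resolve_api_url(Some("https://my-openai-proxy.example.com/v1".into())),
        "https://my-openai-proxy.example.com/v1"
    );
}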
@@ -1567,6 +1588,7 @@ impl Conversation {
             token_count: None,
             max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
             pending_token_count: Task::ready(None),
+            api_url,
             model,
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
@@ -55,6 +55,7 @@ pub struct AssistantSettings {
     pub default_width: Pixels,
     pub default_height: Pixels,
     pub default_open_ai_model: OpenAiModel,
+    pub openai_api_url: String,
 }
 
 /// Assistant panel settings
@@ -80,6 +81,10 @@ pub struct AssistantSettingsContent {
     ///
     /// Default: gpt-4-1106-preview
     pub default_open_ai_model: Option<OpenAiModel>,
+    /// OpenAI API base URL to use when starting new conversations.
+    ///
+    /// Default: https://api.openai.com/v1
+    pub openai_api_url: Option<String>,
 }
 
 impl Settings for AssistantSettings {
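As with the other assistant settings, the user-facing field is an Option that is merged over the shipped default when settings load. The real merge goes through Zed's Settings trait; the following is only a reduced sketch of the defaulting step, with names simplified:

#[derive(Debug)]
struct AssistantSettings {
    openai_api_url: String,
}

#[derive(Default)]
struct AssistantSettingsContent {
    openai_api_url: Option<String>,
}

fn resolve(content: AssistantSettingsContent) -> AssistantSettings {
    AssistantSettings {
        // An unset user value falls back to the documented default.
        openai_api_url: content
            .openai_api_url
            .unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
    }
}

fn main() {
    let resolved = resolve(AssistantSettingsContent::default());
    println!("{resolved:?}");
}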
@@ -8,7 +8,7 @@ mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
 use ai::embedding::{Embedding, EmbeddingProvider};
-use ai::providers::open_ai::OpenAiEmbeddingProvider;
+use ai::providers::open_ai::{OpenAiEmbeddingProvider, OPEN_AI_API_URL};
 use anyhow::{anyhow, Context as _, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -91,8 +91,13 @@ pub fn init(
         .detach();
 
     cx.spawn(move |cx| async move {
-        let embedding_provider =
-            OpenAiEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
+        let embedding_provider = OpenAiEmbeddingProvider::new(
+            // TODO: We should read it from config, but I'm not sure whether to reuse `openai_api_url` in assistant settings or not
+            OPEN_AI_API_URL.to_string(),
+            http_client,
+            cx.background_executor().clone(),
+        )
+        .await;
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,