mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-27 23:59:52 +03:00
Stub out support for Azure OpenAI (#8624)
This PR stubs out support for [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview) within the `OpenAiCompletionProvider`. It still requires some additional wiring so that it is accessible, but the necessary hooks should be in place now. Release Notes: - N/A
This commit is contained in:
parent
cbcd011a36
commit
dab886f479
@ -102,8 +102,9 @@ pub struct OpenAiResponseStreamEvent {
|
|||||||
pub usage: Option<OpenAiUsage>,
|
pub usage: Option<OpenAiUsage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn stream_completion(
|
async fn stream_completion(
|
||||||
api_url: String,
|
api_url: String,
|
||||||
|
kind: OpenAiCompletionProviderKind,
|
||||||
credential: ProviderCredential,
|
credential: ProviderCredential,
|
||||||
executor: BackgroundExecutor,
|
executor: BackgroundExecutor,
|
||||||
request: Box<dyn CompletionRequest>,
|
request: Box<dyn CompletionRequest>,
|
||||||
@ -117,10 +118,11 @@ pub async fn stream_completion(
|
|||||||
|
|
||||||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
|
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
|
||||||
|
|
||||||
|
let (auth_header_name, auth_header_value) = kind.auth_header(api_key);
|
||||||
let json_data = request.data()?;
|
let json_data = request.data()?;
|
||||||
let mut response = Request::post(format!("{api_url}/chat/completions"))
|
let mut response = Request::post(kind.completions_endpoint_url(&api_url))
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.header("Authorization", format!("Bearer {}", api_key))
|
.header(auth_header_name, auth_header_value)
|
||||||
.body(json_data)?
|
.body(json_data)?
|
||||||
.send_async()
|
.send_async()
|
||||||
.await?;
|
.await?;
|
||||||
@ -194,22 +196,65 @@ pub async fn stream_completion(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum OpenAiCompletionProviderKind {
|
||||||
|
OpenAi,
|
||||||
|
AzureOpenAi {
|
||||||
|
deployment_id: String,
|
||||||
|
api_version: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAiCompletionProviderKind {
|
||||||
|
/// Returns the chat completion endpoint URL for this [`OpenAiCompletionProviderKind`].
|
||||||
|
fn completions_endpoint_url(&self, api_url: &str) -> String {
|
||||||
|
match self {
|
||||||
|
Self::OpenAi => {
|
||||||
|
// https://platform.openai.com/docs/api-reference/chat/create
|
||||||
|
format!("{api_url}/chat/completions")
|
||||||
|
}
|
||||||
|
Self::AzureOpenAi {
|
||||||
|
deployment_id,
|
||||||
|
api_version,
|
||||||
|
} => {
|
||||||
|
// https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#completions
|
||||||
|
format!("{api_url}/openai/deployments/{deployment_id}/completions?api-version={api_version}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the authentication header for this [`OpenAiCompletionProviderKind`].
|
||||||
|
fn auth_header(&self, api_key: String) -> (&'static str, String) {
|
||||||
|
match self {
|
||||||
|
Self::OpenAi => ("Authorization", format!("Bearer {api_key}")),
|
||||||
|
Self::AzureOpenAi { .. } => ("Api-Key", api_key),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct OpenAiCompletionProvider {
|
pub struct OpenAiCompletionProvider {
|
||||||
api_url: String,
|
api_url: String,
|
||||||
|
kind: OpenAiCompletionProviderKind,
|
||||||
model: OpenAiLanguageModel,
|
model: OpenAiLanguageModel,
|
||||||
credential: Arc<RwLock<ProviderCredential>>,
|
credential: Arc<RwLock<ProviderCredential>>,
|
||||||
executor: BackgroundExecutor,
|
executor: BackgroundExecutor,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAiCompletionProvider {
|
impl OpenAiCompletionProvider {
|
||||||
pub async fn new(api_url: String, model_name: String, executor: BackgroundExecutor) -> Self {
|
pub async fn new(
|
||||||
|
api_url: String,
|
||||||
|
kind: OpenAiCompletionProviderKind,
|
||||||
|
model_name: String,
|
||||||
|
executor: BackgroundExecutor,
|
||||||
|
) -> Self {
|
||||||
let model = executor
|
let model = executor
|
||||||
.spawn(async move { OpenAiLanguageModel::load(&model_name) })
|
.spawn(async move { OpenAiLanguageModel::load(&model_name) })
|
||||||
.await;
|
.await;
|
||||||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||||
Self {
|
Self {
|
||||||
api_url,
|
api_url,
|
||||||
|
kind,
|
||||||
model,
|
model,
|
||||||
credential,
|
credential,
|
||||||
executor,
|
executor,
|
||||||
@ -297,6 +342,7 @@ impl CompletionProvider for OpenAiCompletionProvider {
|
|||||||
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||||
model
|
model
|
||||||
}
|
}
|
||||||
|
|
||||||
fn complete(
|
fn complete(
|
||||||
&self,
|
&self,
|
||||||
prompt: Box<dyn CompletionRequest>,
|
prompt: Box<dyn CompletionRequest>,
|
||||||
@ -307,7 +353,8 @@ impl CompletionProvider for OpenAiCompletionProvider {
|
|||||||
// At some point in the future we should rectify this.
|
// At some point in the future we should rectify this.
|
||||||
let credential = self.credential.read().clone();
|
let credential = self.credential.read().clone();
|
||||||
let api_url = self.api_url.clone();
|
let api_url = self.api_url.clone();
|
||||||
let request = stream_completion(api_url, credential, self.executor.clone(), prompt);
|
let kind = self.kind.clone();
|
||||||
|
let request = stream_completion(api_url, kind, credential, self.executor.clone(), prompt);
|
||||||
async move {
|
async move {
|
||||||
let response = request.await?;
|
let response = request.await?;
|
||||||
let stream = response
|
let stream = response
|
||||||
@ -322,6 +369,7 @@ impl CompletionProvider for OpenAiCompletionProvider {
|
|||||||
}
|
}
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn box_clone(&self) -> Box<dyn CompletionProvider> {
|
fn box_clone(&self) -> Box<dyn CompletionProvider> {
|
||||||
Box::new((*self).clone())
|
Box::new((*self).clone())
|
||||||
}
|
}
|
||||||
|
@ -7,11 +7,13 @@ use crate::{
|
|||||||
SavedMessage, Split, ToggleFocus, ToggleIncludeConversation, ToggleRetrieveContext,
|
SavedMessage, Split, ToggleFocus, ToggleIncludeConversation, ToggleRetrieveContext,
|
||||||
};
|
};
|
||||||
use ai::prompts::repository_context::PromptCodeSnippet;
|
use ai::prompts::repository_context::PromptCodeSnippet;
|
||||||
use ai::providers::open_ai::OPEN_AI_API_URL;
|
|
||||||
use ai::{
|
use ai::{
|
||||||
auth::ProviderCredential,
|
auth::ProviderCredential,
|
||||||
completion::{CompletionProvider, CompletionRequest},
|
completion::{CompletionProvider, CompletionRequest},
|
||||||
providers::open_ai::{OpenAiCompletionProvider, OpenAiRequest, RequestMessage},
|
providers::open_ai::{
|
||||||
|
OpenAiCompletionProvider, OpenAiCompletionProviderKind, OpenAiRequest, RequestMessage,
|
||||||
|
OPEN_AI_API_URL,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use chrono::{DateTime, Local};
|
use chrono::{DateTime, Local};
|
||||||
@ -131,6 +133,7 @@ impl AssistantPanel {
|
|||||||
})?;
|
})?;
|
||||||
let completion_provider = OpenAiCompletionProvider::new(
|
let completion_provider = OpenAiCompletionProvider::new(
|
||||||
api_url,
|
api_url,
|
||||||
|
OpenAiCompletionProviderKind::OpenAi,
|
||||||
model_name,
|
model_name,
|
||||||
cx.background_executor().clone(),
|
cx.background_executor().clone(),
|
||||||
)
|
)
|
||||||
@ -1533,6 +1536,7 @@ impl Conversation {
|
|||||||
api_url
|
api_url
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| OPEN_AI_API_URL.to_string()),
|
.unwrap_or_else(|| OPEN_AI_API_URL.to_string()),
|
||||||
|
OpenAiCompletionProviderKind::OpenAi,
|
||||||
model.full_name().into(),
|
model.full_name().into(),
|
||||||
cx.background_executor().clone(),
|
cx.background_executor().clone(),
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user