Stub out support for Azure OpenAI (#8624)

This PR stubs out support for [Azure
OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview)
within the `OpenAiCompletionProvider`.

It still requires some additional wiring before it is accessible to users, but
the necessary hooks should now be in place.
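
For illustration, here is a rough sketch of how the Azure variant might eventually be constructed once the remaining wiring lands. The resource URL, deployment ID, API version, and model name below are placeholders, not values from this PR; `executor` is assumed to be a `gpui::BackgroundExecutor` obtained from the app context.

```rust
use ai::providers::open_ai::{OpenAiCompletionProvider, OpenAiCompletionProviderKind};
use gpui::BackgroundExecutor;

// Placeholder values throughout; real wiring would read these from settings.
async fn azure_provider(executor: BackgroundExecutor) -> OpenAiCompletionProvider {
    OpenAiCompletionProvider::new(
        "https://my-resource.openai.azure.com".to_string(),
        OpenAiCompletionProviderKind::AzureOpenAi {
            deployment_id: "my-gpt-4-deployment".to_string(),
            api_version: "2023-05-15".to_string(),
        },
        "gpt-4".to_string(),
        executor,
    )
    .await
}
```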

Release Notes:

- N/A
Author: Marshall Bowers (committed by GitHub)
Date: 2024-02-29 13:02:08 -05:00
Commit: dab886f479
Parent: cbcd011a36
3 changed files with 59 additions and 7 deletions

View File

@@ -102,8 +102,9 @@ pub struct OpenAiResponseStreamEvent {
     pub usage: Option<OpenAiUsage>,
 }

-pub async fn stream_completion(
+async fn stream_completion(
     api_url: String,
+    kind: OpenAiCompletionProviderKind,
     credential: ProviderCredential,
     executor: BackgroundExecutor,
     request: Box<dyn CompletionRequest>,
@@ -117,10 +118,11 @@ pub async fn stream_completion(
     let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();

+    let (auth_header_name, auth_header_value) = kind.auth_header(api_key);
     let json_data = request.data()?;
-    let mut response = Request::post(format!("{api_url}/chat/completions"))
+    let mut response = Request::post(kind.completions_endpoint_url(&api_url))
         .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key))
+        .header(auth_header_name, auth_header_value)
         .body(json_data)?
         .send_async()
         .await?;
@@ -194,22 +196,65 @@ pub async fn stream_completion(
     }
 }

+#[derive(Clone)]
+pub enum OpenAiCompletionProviderKind {
+    OpenAi,
+    AzureOpenAi {
+        deployment_id: String,
+        api_version: String,
+    },
+}
+
+impl OpenAiCompletionProviderKind {
+    /// Returns the chat completion endpoint URL for this [`OpenAiCompletionProviderKind`].
+    fn completions_endpoint_url(&self, api_url: &str) -> String {
+        match self {
+            Self::OpenAi => {
+                // https://platform.openai.com/docs/api-reference/chat/create
+                format!("{api_url}/chat/completions")
+            }
+            Self::AzureOpenAi {
+                deployment_id,
+                api_version,
+            } => {
+                // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#completions
+                format!("{api_url}/openai/deployments/{deployment_id}/completions?api-version={api_version}")
+            }
+        }
+    }
+
+    /// Returns the authentication header for this [`OpenAiCompletionProviderKind`].
+    fn auth_header(&self, api_key: String) -> (&'static str, String) {
+        match self {
+            Self::OpenAi => ("Authorization", format!("Bearer {api_key}")),
+            Self::AzureOpenAi { .. } => ("Api-Key", api_key),
+        }
+    }
+}
+
 #[derive(Clone)]
 pub struct OpenAiCompletionProvider {
     api_url: String,
+    kind: OpenAiCompletionProviderKind,
     model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     executor: BackgroundExecutor,
 }

 impl OpenAiCompletionProvider {
-    pub async fn new(api_url: String, model_name: String, executor: BackgroundExecutor) -> Self {
+    pub async fn new(
+        api_url: String,
+        kind: OpenAiCompletionProviderKind,
+        model_name: String,
+        executor: BackgroundExecutor,
+    ) -> Self {
         let model = executor
             .spawn(async move { OpenAiLanguageModel::load(&model_name) })
             .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
         Self {
             api_url,
+            kind,
             model,
             credential,
             executor,
@@ -297,6 +342,7 @@ impl CompletionProvider for OpenAiCompletionProvider {
         let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
         model
     }
+
     fn complete(
         &self,
         prompt: Box<dyn CompletionRequest>,
@@ -307,7 +353,8 @@ impl CompletionProvider for OpenAiCompletionProvider {
         // At some point in the future we should rectify this.
         let credential = self.credential.read().clone();
         let api_url = self.api_url.clone();
-        let request = stream_completion(api_url, credential, self.executor.clone(), prompt);
+        let kind = self.kind.clone();
+        let request = stream_completion(api_url, kind, credential, self.executor.clone(), prompt);
         async move {
             let response = request.await?;
             let stream = response
@@ -322,6 +369,7 @@ impl CompletionProvider for OpenAiCompletionProvider {
             }
             .boxed()
         }
+
         fn box_clone(&self) -> Box<dyn CompletionProvider> {
             Box::new((*self).clone())
         }

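For context, here is a sketch of what the two kinds resolve to. Since `completions_endpoint_url` and `auth_header` are private to the module, something like this would have to live alongside them (for example as a unit test); the Azure resource URL, deployment ID, and API version are placeholders.

```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn completion_endpoints_and_auth_headers() {
        // Standard OpenAI: fixed chat completions path, Bearer auth.
        let openai = OpenAiCompletionProviderKind::OpenAi;
        assert_eq!(
            openai.completions_endpoint_url("https://api.openai.com/v1"),
            "https://api.openai.com/v1/chat/completions"
        );
        assert_eq!(
            openai.auth_header("sk-123".to_string()),
            ("Authorization", "Bearer sk-123".to_string())
        );

        // Azure OpenAI: deployment-scoped path with an api-version query, Api-Key auth.
        let azure = OpenAiCompletionProviderKind::AzureOpenAi {
            deployment_id: "my-gpt-4-deployment".to_string(),
            api_version: "2023-05-15".to_string(),
        };
        assert_eq!(
            azure.completions_endpoint_url("https://my-resource.openai.azure.com"),
            "https://my-resource.openai.azure.com/openai/deployments/my-gpt-4-deployment/completions?api-version=2023-05-15"
        );
        assert_eq!(
            azure.auth_header("azure-key".to_string()),
            ("Api-Key", "azure-key".to_string())
        );
    }
}
```
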
View File

@@ -7,11 +7,13 @@ use crate::{
     SavedMessage, Split, ToggleFocus, ToggleIncludeConversation, ToggleRetrieveContext,
 };
 use ai::prompts::repository_context::PromptCodeSnippet;
-use ai::providers::open_ai::OPEN_AI_API_URL;
 use ai::{
     auth::ProviderCredential,
     completion::{CompletionProvider, CompletionRequest},
-    providers::open_ai::{OpenAiCompletionProvider, OpenAiRequest, RequestMessage},
+    providers::open_ai::{
+        OpenAiCompletionProvider, OpenAiCompletionProviderKind, OpenAiRequest, RequestMessage,
+        OPEN_AI_API_URL,
+    },
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
@@ -131,6 +133,7 @@ impl AssistantPanel {
             })?;
             let completion_provider = OpenAiCompletionProvider::new(
                 api_url,
+                OpenAiCompletionProviderKind::OpenAi,
                 model_name,
                 cx.background_executor().clone(),
             )
@@ -1533,6 +1536,7 @@ impl Conversation {
                 api_url
                     .clone()
                     .unwrap_or_else(|| OPEN_AI_API_URL.to_string()),
+                OpenAiCompletionProviderKind::OpenAi,
                 model.full_name().into(),
                 cx.background_executor().clone(),
             )
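
The call sites above hardcode `OpenAiCompletionProviderKind::OpenAi`, which is where the remaining wiring comes in. One possible shape for that follow-up is sketched below; the `AzureOpenAiSettings` struct and its fields are hypothetical and not existing Zed settings, they only illustrate how a configured Azure deployment could select the new variant.

```rust
use ai::providers::open_ai::OpenAiCompletionProviderKind;

/// Hypothetical settings shape for an Azure deployment; not an existing Zed settings struct.
struct AzureOpenAiSettings {
    deployment_id: String,
    api_version: String,
}

/// Pick the provider kind based on whether Azure settings are present.
fn provider_kind(azure: Option<&AzureOpenAiSettings>) -> OpenAiCompletionProviderKind {
    match azure {
        Some(settings) => OpenAiCompletionProviderKind::AzureOpenAi {
            deployment_id: settings.deployment_id.clone(),
            api_version: settings.api_version.clone(),
        },
        None => OpenAiCompletionProviderKind::OpenAi,
    }
}
```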