mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-28 22:24:47 +03:00
move OpenAICompletionProvider to providers location
This commit is contained in:
parent
3712794e56
commit
05ae978cb7
@ -1,177 +1,7 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{
|
||||
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
|
||||
Stream, StreamExt,
|
||||
};
|
||||
use gpui::executor::Background;
|
||||
use isahc::{http::StatusCode, Request, RequestExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fmt::{self, Display},
|
||||
io,
|
||||
sync::Arc,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use futures::{future::BoxFuture, stream::BoxStream};
|
||||
|
||||
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn cycle(&mut self) {
|
||||
*self = match self {
|
||||
Role::User => Role::Assistant,
|
||||
Role::Assistant => Role::System,
|
||||
Role::System => Role::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Role::User => write!(f, "User"),
|
||||
Role::Assistant => write!(f, "Assistant"),
|
||||
Role::System => write!(f, "System"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize)]
|
||||
pub struct OpenAIRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ResponseMessage {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAIUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ChatChoiceDelta {
|
||||
pub index: u32,
|
||||
pub delta: ResponseMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAIResponseStreamEvent {
|
||||
pub id: Option<String>,
|
||||
pub object: String,
|
||||
pub created: u32,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatChoiceDelta>,
|
||||
pub usage: Option<OpenAIUsage>,
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
api_key: String,
|
||||
executor: Arc<Background>,
|
||||
mut request: OpenAIRequest,
|
||||
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
||||
request.stream = true;
|
||||
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
||||
|
||||
let json_data = serde_json::to_string(&request)?;
|
||||
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(json_data)?
|
||||
.send_async()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if status == StatusCode::OK {
|
||||
executor
|
||||
.spawn(async move {
|
||||
let mut lines = BufReader::new(response.body_mut()).lines();
|
||||
|
||||
fn parse_line(
|
||||
line: Result<String, io::Error>,
|
||||
) -> Result<Option<OpenAIResponseStreamEvent>> {
|
||||
if let Some(data) = line?.strip_prefix("data: ") {
|
||||
let event = serde_json::from_str(&data)?;
|
||||
Ok(Some(event))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(line) = lines.next().await {
|
||||
if let Some(event) = parse_line(line).transpose() {
|
||||
let done = event.as_ref().map_or(false, |event| {
|
||||
event
|
||||
.choices
|
||||
.last()
|
||||
.map_or(false, |choice| choice.finish_reason.is_some())
|
||||
});
|
||||
if tx.unbounded_send(event).is_err() {
|
||||
break;
|
||||
}
|
||||
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
|
||||
Ok(rx)
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIResponse {
|
||||
error: OpenAIError,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIError {
|
||||
message: String,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<OpenAIResponse>(&body) {
|
||||
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {}",
|
||||
response.error.message,
|
||||
)),
|
||||
|
||||
_ => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
use crate::providers::open_ai::completion::OpenAIRequest;
|
||||
|
||||
pub trait CompletionProvider {
|
||||
fn complete(
|
||||
@ -179,36 +9,3 @@ pub trait CompletionProvider {
|
||||
prompt: OpenAIRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||
}
|
||||
|
||||
pub struct OpenAICompletionProvider {
|
||||
api_key: String,
|
||||
executor: Arc<Background>,
|
||||
}
|
||||
|
||||
impl OpenAICompletionProvider {
|
||||
pub fn new(api_key: String, executor: Arc<Background>) -> Self {
|
||||
Self { api_key, executor }
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionProvider for OpenAICompletionProvider {
|
||||
fn complete(
|
||||
&self,
|
||||
prompt: OpenAIRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
|
||||
async move {
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
209
crates/ai/src/providers/open_ai/completion.rs
Normal file
209
crates/ai/src/providers/open_ai/completion.rs
Normal file
@ -0,0 +1,209 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{
|
||||
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
|
||||
Stream, StreamExt,
|
||||
};
|
||||
use gpui::executor::Background;
|
||||
use isahc::{http::StatusCode, Request, RequestExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fmt::{self, Display},
|
||||
io,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::completion::CompletionProvider;
|
||||
|
||||
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn cycle(&mut self) {
|
||||
*self = match self {
|
||||
Role::User => Role::Assistant,
|
||||
Role::Assistant => Role::System,
|
||||
Role::System => Role::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Role::User => write!(f, "User"),
|
||||
Role::Assistant => write!(f, "Assistant"),
|
||||
Role::System => write!(f, "System"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize)]
|
||||
pub struct OpenAIRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ResponseMessage {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAIUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ChatChoiceDelta {
|
||||
pub index: u32,
|
||||
pub delta: ResponseMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAIResponseStreamEvent {
|
||||
pub id: Option<String>,
|
||||
pub object: String,
|
||||
pub created: u32,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatChoiceDelta>,
|
||||
pub usage: Option<OpenAIUsage>,
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
api_key: String,
|
||||
executor: Arc<Background>,
|
||||
mut request: OpenAIRequest,
|
||||
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
||||
request.stream = true;
|
||||
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
||||
|
||||
let json_data = serde_json::to_string(&request)?;
|
||||
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(json_data)?
|
||||
.send_async()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if status == StatusCode::OK {
|
||||
executor
|
||||
.spawn(async move {
|
||||
let mut lines = BufReader::new(response.body_mut()).lines();
|
||||
|
||||
fn parse_line(
|
||||
line: Result<String, io::Error>,
|
||||
) -> Result<Option<OpenAIResponseStreamEvent>> {
|
||||
if let Some(data) = line?.strip_prefix("data: ") {
|
||||
let event = serde_json::from_str(&data)?;
|
||||
Ok(Some(event))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(line) = lines.next().await {
|
||||
if let Some(event) = parse_line(line).transpose() {
|
||||
let done = event.as_ref().map_or(false, |event| {
|
||||
event
|
||||
.choices
|
||||
.last()
|
||||
.map_or(false, |choice| choice.finish_reason.is_some())
|
||||
});
|
||||
if tx.unbounded_send(event).is_err() {
|
||||
break;
|
||||
}
|
||||
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
|
||||
Ok(rx)
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIResponse {
|
||||
error: OpenAIError,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIError {
|
||||
message: String,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<OpenAIResponse>(&body) {
|
||||
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {}",
|
||||
response.error.message,
|
||||
)),
|
||||
|
||||
_ => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OpenAICompletionProvider {
|
||||
api_key: String,
|
||||
executor: Arc<Background>,
|
||||
}
|
||||
|
||||
impl OpenAICompletionProvider {
|
||||
pub fn new(api_key: String, executor: Arc<Background>) -> Self {
|
||||
Self { api_key, executor }
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionProvider for OpenAICompletionProvider {
|
||||
fn complete(
|
||||
&self,
|
||||
prompt: OpenAIRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
|
||||
async move {
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
@ -1,2 +1,4 @@
|
||||
pub mod completion;
|
||||
pub mod model;
|
||||
pub use completion::*;
|
||||
pub use model::OpenAILanguageModel;
|
||||
|
@ -4,7 +4,7 @@ mod codegen;
|
||||
mod prompts;
|
||||
mod streaming_diff;
|
||||
|
||||
use ai::completion::Role;
|
||||
use ai::providers::open_ai::Role;
|
||||
use anyhow::Result;
|
||||
pub use assistant_panel::AssistantPanel;
|
||||
use assistant_settings::OpenAIModel;
|
||||
|
@ -5,12 +5,12 @@ use crate::{
|
||||
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
|
||||
SavedMessage,
|
||||
};
|
||||
use ai::{
|
||||
completion::{
|
||||
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
|
||||
},
|
||||
prompts::repository_context::PromptCodeSnippet,
|
||||
|
||||
use ai::providers::open_ai::{
|
||||
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
|
||||
};
|
||||
|
||||
use ai::prompts::repository_context::PromptCodeSnippet;
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::{DateTime, Local};
|
||||
use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
|
||||
|
@ -1,5 +1,6 @@
|
||||
use crate::streaming_diff::{Hunk, StreamingDiff};
|
||||
use ai::completion::{CompletionProvider, OpenAIRequest};
|
||||
use ai::completion::CompletionProvider;
|
||||
use ai::providers::open_ai::OpenAIRequest;
|
||||
use anyhow::Result;
|
||||
use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
|
||||
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
|
||||
|
Loading…
Reference in New Issue
Block a user